feat: multi-threaded stream support

This commit is contained in:
divyam234 2024-06-22 17:59:59 +05:30
parent 8c9ca99a93
commit 825dc11fe1
27 changed files with 984 additions and 788 deletions

36
.github/workflows/build-dev.yml vendored Normal file
View file

@ -0,0 +1,36 @@
# Builds a Docker image from the dev branch and pushes it to GitHub Container
# Registry tagged ":dev" on every push to dev.
name: Build Dev
on:
  push:
    branches:
      - dev
permissions: write-all
jobs:
  goreleaser:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      # ghcr.io requires a lowercase repository name; ${GITHUB_REPOSITORY,,}
      # is bash's lowercase expansion.
      - name: Set Vars
        run: |
          echo "IMAGE=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
      - name: Build Image
        uses: docker/build-push-action@v6
        with:
          context: .
          push: true
          tags: ${{ env.IMAGE }}:dev

24
Dockerfile Normal file
View file

@ -0,0 +1,24 @@
FROM golang:alpine AS builder
RUN apk add --no-cache git unzip curl make bash
WORKDIR /app
# Cache module downloads separately from source changes.
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN make build

FROM scratch
WORKDIR /
# FIX: scratch contains no CA bundle, so outbound TLS (Telegram, database)
# would fail certificate verification; copy the certs from the builder image.
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /app/bin/teldrive /teldrive
EXPOSE 8080
ENTRYPOINT ["/teldrive","run","--tg-session-file","/session.db"]

View file

@ -11,19 +11,16 @@ FRONTEND_ASSET := https://github.com/divyam234/teldrive-ui/releases/download/v1/
GIT_TAG := $(shell git describe --tags --abbrev=0)
GIT_COMMIT := $(shell git rev-parse --short HEAD)
GIT_LINK := $(shell git remote get-url origin)
ENV_FILE := $(FRONTEND_DIR)/.env
MODULE_PATH := $(shell go list -m)
BUILD_DATE := $(shell $(BUILD_DATE))
GOOS ?= $(shell go env GOOS)
GOARCH ?= $(shell go env GOARCH)
VERSION:= $(GIT_TAG)
BINARY_EXTENSION :=
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version
all: build
frontend:
@echo "Extract UI"
ifeq ($(OS),Windows_NT)
@ -40,10 +37,13 @@ else
rm -rf teldrive-ui.zip
endif
ifeq ($(OS),Windows_NT)
BINARY_EXTENSION := .exe
endif
backend:
@echo "Building backend for $(GOOS)/$(GOARCH)..."
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(GIT_TAG) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(VERSION) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION)
build: frontend backend
@echo "Building complete."

View file

@ -28,6 +28,8 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi
files.PATCH(":fileID", authmiddleware, c.UpdateFile)
files.HEAD(":fileID/stream/:fileName", c.GetFileStream)
files.GET(":fileID/stream/:fileName", c.GetFileStream)
files.HEAD(":fileID/download/:fileName", c.GetFileDownload)
files.GET(":fileID/download/:fileName", c.GetFileDownload)
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
files.POST("/move", authmiddleware, c.MoveFiles)

View file

@ -89,12 +89,11 @@ func NewRun() *cobra.Command {
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
runCmd.Flags().IntVar(&config.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries")
runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size")
runCmd.Flags().BoolVar(&config.TG.Stream.BufferReader, "tg-stream-buffer-reader", false, "Async Buffered reader for fast streaming")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 16, "No of Stream buffers")
runCmd.Flags().BoolVar(&config.TG.Stream.UseMmap, "tg-stream-use-mmap", false, "Use mmap for stream buffers")
runCmd.Flags().BoolVar(&config.TG.Stream.UsePooling, "tg-stream-use-pooling", false, "Use session pooling for stream workers")
duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
runCmd.MarkFlagRequired("tg-app-id")
runCmd.MarkFlagRequired("tg-app-hash")
runCmd.MarkFlagRequired("db-data-source")
@ -162,11 +161,11 @@ func initViperConfig(cmd *cobra.Command) error {
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv()
viper.ReadInConfig()
bindFlagsRecursive(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
bindFlags(cmd.Flags(), "", reflect.ValueOf(config.Config{}))
return nil
}
func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
func bindFlags(flags *pflag.FlagSet, prefix string, v reflect.Value) {
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
@ -175,7 +174,7 @@ func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) {
field := t.Field(i)
switch field.Type.Kind() {
case reflect.Struct:
bindFlagsRecursive(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
bindFlags(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i))
default:
newPrefix := prefix[1:]
newName := modifyFlag(field.Name)

View file

@ -43,8 +43,6 @@
retention = "7d"
threads = 8
[tg.stream]
buffer-reader = false
buffers = 6
use-mmap = false
use-pooling= false
multi-threads = 0
buffers = 16

3
go.mod
View file

@ -16,7 +16,6 @@ require (
github.com/magiconair/properties v1.8.7
github.com/mitchellh/go-homedir v1.1.0
github.com/pkg/errors v0.9.1
github.com/rclone/rclone v1.67.0
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
@ -97,7 +96,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pressly/goose/v3 v3.20.0
github.com/pressly/goose/v3 v3.21.1
github.com/segmentio/asm v1.2.0 // indirect
github.com/stretchr/testify v1.9.0
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

10
go.sum
View file

@ -145,8 +145,8 @@ github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbW
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/microsoft/go-mssqldb v1.7.0 h1:sgMPW0HA6Ihd37Yx0MzHyKD726C2kY/8KJsQtXHNaAs=
github.com/microsoft/go-mssqldb v1.7.0/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM=
github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
@ -166,10 +166,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose/v3 v3.20.0 h1:uPJdOxF/Ipj7ABVNOAMJXSxwFXZGwMGHNqjC8e61VA0=
github.com/pressly/goose/v3 v3.20.0/go.mod h1:BRfF2GcG4FTG12QfdBVy3q1yveaf4ckL9vWwEcIO3lA=
github.com/rclone/rclone v1.67.0 h1:yLRNgHEG2vQ60HCuzFqd0hYwKCRuWuvPUhvhMJ2jI5E=
github.com/rclone/rclone v1.67.0/go.mod h1:Cb3Ar47M/SvwfhAjZTbVXdtrP/JLtPFCq2tkdtBVC6w=
github.com/pressly/goose/v3 v3.21.1 h1:5SSAKKWej8LVVzNLuT6KIvP1eFDuPvxa+B6H0w78buQ=
github.com/pressly/goose/v3 v3.21.1/go.mod h1:sqthmzV8PitchEkjecFJII//l43dLOCzfWh8pHEe+vE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=

View file

@ -2,8 +2,10 @@ package auth
import (
"encoding/json"
"strconv"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3"
)
@ -59,3 +61,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
return jwtToken, nil
}
func GetUser(c *gin.Context) (int64, string) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}

View file

@ -43,10 +43,9 @@ type TGConfig struct {
Retention time.Duration
}
Stream struct {
BufferReader bool
MultiThreads int
Buffers int
UseMmap bool
UsePooling bool
ChunkTimeout time.Duration
}
}

View file

@ -1,206 +0,0 @@
// taken from rclone async reader implmentation
package reader
import (
"context"
"errors"
"io"
"sync"
"time"
"github.com/rclone/rclone/lib/pool"
)
const (
BufferSize = 1024 * 1024
softStartInitial = 4 * 1024
bufferCacheSize = 64
bufferCacheFlushTime = 5 * time.Second
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
type AsyncReader struct {
in io.ReadCloser
ready chan *buffer
token chan struct{}
exit chan struct{}
buffers int
err error
cur *buffer
exited chan struct{}
closed bool
mu sync.Mutex
}
func NewAsyncReader(ctx context.Context, rd io.ReadCloser, buffers int) (*AsyncReader, error) {
if buffers <= 0 {
return nil, errors.New("number of buffers too small")
}
if rd == nil {
return nil, errors.New("nil reader supplied")
}
a := &AsyncReader{}
a.init(rd, buffers)
return a, nil
}
func (a *AsyncReader) init(rd io.ReadCloser, buffers int) {
a.in = rd
a.ready = make(chan *buffer, buffers)
a.token = make(chan struct{}, buffers)
a.exit = make(chan struct{})
a.exited = make(chan struct{})
a.buffers = buffers
a.cur = nil
for i := 0; i < buffers; i++ {
a.token <- struct{}{}
}
go func() {
defer close(a.exited)
defer close(a.ready)
for {
select {
case <-a.token:
b := a.getBuffer()
err := b.read(a.in)
a.ready <- b
if err != nil {
return
}
case <-a.exit:
return
}
}
}()
}
var bufferPool *pool.Pool
var bufferPoolOnce sync.Once
func (a *AsyncReader) putBuffer(b *buffer) {
bufferPool.Put(b.buf)
b.buf = nil
}
func (a *AsyncReader) getBuffer() *buffer {
bufferPoolOnce.Do(func() {
bufferPool = pool.New(bufferCacheFlushTime, BufferSize, bufferCacheSize, false)
})
return &buffer{
buf: bufferPool.Get(),
}
}
func (a *AsyncReader) fill() (err error) {
if a.cur.isEmpty() {
if a.cur != nil {
a.putBuffer(a.cur)
a.token <- struct{}{}
a.cur = nil
}
b, ok := <-a.ready
if !ok {
if a.err == nil {
return ErrorStreamAbandoned
}
return a.err
}
a.cur = b
}
return nil
}
func (a *AsyncReader) Read(p []byte) (n int, err error) {
a.mu.Lock()
defer a.mu.Unlock()
err = a.fill()
if err != nil {
return 0, err
}
n = copy(p, a.cur.buffer())
a.cur.increment(n)
if a.cur.isEmpty() {
a.err = a.cur.err
return n, a.err
}
return n, nil
}
func (a *AsyncReader) StopBuffering() {
select {
case <-a.exit:
return
default:
}
close(a.exit)
<-a.exited
}
func (a *AsyncReader) Abandon() {
a.StopBuffering()
a.mu.Lock()
defer a.mu.Unlock()
if a.cur != nil {
a.putBuffer(a.cur)
a.cur = nil
}
for b := range a.ready {
a.putBuffer(b)
}
}
func (a *AsyncReader) Close() (err error) {
a.Abandon()
if a.closed {
return nil
}
a.closed = true
return a.in.Close()
}
type buffer struct {
buf []byte
err error
offset int
}
func (b *buffer) isEmpty() bool {
if b == nil {
return true
}
if len(b.buf)-b.offset <= 0 {
return true
}
return false
}
func (b *buffer) readFill(r io.Reader, buf []byte) (n int, err error) {
var nn int
for n < len(buf) && err == nil {
nn, err = r.Read(buf[n:])
n += nn
}
return n, err
}
func (b *buffer) read(rd io.Reader) error {
var n int
n, b.err = b.readFill(rd, b.buf)
b.buf = b.buf[0:n]
b.offset = 0
return b.err
}
func (b *buffer) buffer() []byte {
return b.buf[b.offset:]
}
func (b *buffer) increment(n int) {
b.offset += n
}

View file

@ -6,36 +6,47 @@ import (
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/tg"
)
type decrpytedReader struct {
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
client *tg.Client
reader io.ReadCloser
limit int64
err error
config *config.TGConfig
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
reader io.ReadCloser
limit int64
config *config.TGConfig
channelId int64
worker *tgc.StreamWorker
client *tgc.Client
fileId string
concurrency int
}
func NewDecryptedReader(
ctx context.Context,
client *tg.Client,
fileId string,
parts []types.Part,
start, end int64,
config *config.TGConfig) (io.ReadCloser, error) {
channelId int64,
config *config.TGConfig,
concurrency int,
client *tgc.Client,
worker *tgc.StreamWorker) (io.ReadCloser, error) {
r := &decrpytedReader{
ctx: ctx,
parts: parts,
client: client,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
config: config,
ctx: ctx,
parts: parts,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
config: config,
client: client,
worker: worker,
channelId: channelId,
fileId: fileId,
concurrency: concurrency,
}
res, err := r.nextPart()
@ -51,30 +62,24 @@ func NewDecryptedReader(
func (r *decrpytedReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
if r.limit <= 0 {
return 0, io.EOF
}
n, err = r.reader.Read(p)
if err == nil {
r.limit -= int64(n)
}
r.limit -= int64(n)
if err == io.EOF {
if r.limit > 0 {
err = nil
if r.reader != nil {
r.reader.Close()
}
}
r.pos++
if r.pos < len(r.ranges) {
r.reader, err = r.nextPart()
}
}
r.err = err
return
}
@ -89,7 +94,6 @@ func (r *decrpytedReader) Close() (err error) {
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
location := r.parts[r.ranges[r.pos].PartNo].Location
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
salt := r.parts[r.ranges[r.pos].PartNo].Salt
@ -99,21 +103,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
func(ctx context.Context,
underlyingOffset,
underlyingLimit int64) (io.ReadCloser, error) {
var end int64
if underlyingLimit >= 0 {
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
}
rd, err := newTGReader(r.ctx, r.client, location, underlyingOffset, end)
if err != nil {
return nil, err
}
if r.config.Stream.BufferReader {
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
}
return rd, nil
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
}, start, end-start+1)

View file

@ -5,8 +5,8 @@ import (
"io"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/tg"
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
@ -37,31 +37,42 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
}
type linearReader struct {
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
client *tg.Client
reader io.ReadCloser
limit int64
err error
config *config.TGConfig
ctx context.Context
parts []types.Part
ranges []types.Range
pos int
reader io.ReadCloser
limit int64
config *config.TGConfig
channelId int64
worker *tgc.StreamWorker
client *tgc.Client
fileId string
concurrency int
}
func NewLinearReader(ctx context.Context,
client *tg.Client,
fileId string,
parts []types.Part,
start, end int64,
channelId int64,
config *config.TGConfig,
concurrency int,
client *tgc.Client,
worker *tgc.StreamWorker,
) (reader io.ReadCloser, err error) {
r := &linearReader{
ctx: ctx,
parts: parts,
client: client,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].Size),
config: config,
ctx: ctx,
parts: parts,
limit: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].Size),
config: config,
client: client,
worker: worker,
channelId: channelId,
fileId: fileId,
concurrency: concurrency,
}
r.reader, err = r.nextPart()
@ -73,25 +84,22 @@ func NewLinearReader(ctx context.Context,
return r, nil
}
func (r *linearReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
func (r *linearReader) Read(p []byte) (int, error) {
if r.limit <= 0 {
return 0, io.EOF
}
n, err = r.reader.Read(p)
n, err := r.reader.Read(p)
if err == nil {
r.limit -= int64(n)
}
r.limit -= int64(n)
if err == io.EOF {
if r.limit > 0 {
err = nil
if r.reader != nil {
r.reader.Close()
}
}
r.pos++
if r.pos < len(r.ranges) {
@ -99,24 +107,18 @@ func (r *linearReader) Read(p []byte) (n int, err error) {
}
}
r.err = err
return
return n, err
}
func (r *linearReader) nextPart() (io.ReadCloser, error) {
location := r.parts[r.ranges[r.pos].PartNo].Location
startByte := r.ranges[r.pos].Start
endByte := r.ranges[r.pos].End
rd, err := newTGReader(r.ctx, r.client, location, startByte, endByte)
if err != nil {
return nil, err
}
if r.config.Stream.BufferReader {
return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers)
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
}
return rd, nil
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
}

View file

@ -0,0 +1,311 @@
package reader
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/gotd/td/tg"
"golang.org/x/sync/errgroup"
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
// ChunkSource abstracts fetching a byte range ("chunk") of a remote file,
// so tgReader can be exercised in tests without a live Telegram client.
type ChunkSource interface {
	// Chunk returns up to limit bytes starting at offset.
	Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
	// ChunkSize reports the chunk size to use for the byte range [start, end].
	ChunkSize(start, end int64) int64
}

// chunkSource fetches chunks of one file part stored in a Telegram channel.
type chunkSource struct {
	channelId   int64             // channel holding the file's messages
	worker      *tgc.StreamWorker // session pool, consulted when concurrency > 0
	fileId      string
	partId      int64
	concurrency int         // > 0 selects a pooled client per fetch
	client      *tgc.Client // fixed client used when concurrency == 0
}

func (c *chunkSource) ChunkSize(start, end int64) int64 {
	return tgc.CalculateChunkSize(start, end)
}

// Chunk resolves the part's file location and downloads [offset, offset+limit).
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {

	var (
		location *tg.InputDocumentFileLocation
		err      error
		client   *tgc.Client
	)

	client = c.client

	if c.concurrency > 0 {
		// NOTE(review): the error from Next is discarded — presumably Next
		// always yields a usable client; confirm against tgc.StreamWorker.
		client, _, _ = c.worker.Next(c.channelId)
	}

	location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)

	if err != nil {
		return nil, err
	}

	return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}
// tgReader streams a byte range of a Telegram file as an io.ReadCloser.
// A background goroutine fills bufferChan with chunks; Read drains them.
type tgReader struct {
	ctx         context.Context
	offset      int64         // next chunk-aligned fetch offset
	limit       int64         // unread bytes remaining in the requested range
	chunkSize   int64         // protocol chunk size chosen by the ChunkSource
	bufferChan  chan *buffer  // fetched chunks, delivered in order
	done        chan struct{} // closed by Close to stop the fill goroutine
	cur         *buffer       // chunk currently being drained by Read
	err         chan error    // fatal fetch error reported by the fill goroutine
	mu          sync.Mutex    // serializes Read
	concurrency int           // parallel fetchers; 0 selects the sequential path
	leftCut     int64         // bytes to trim from the front of the first chunk
	rightCut    int64         // bytes to keep of the last chunk
	totalParts  int           // total number of chunks covering the range
	currentPart int           // progress counter for the fill goroutine
	closed      bool          // set once bufferChan has been closed
	timeout     time.Duration // chunk fetch timeout
	chunkSrc    ChunkSource   // source of chunk data
}
// newTGReader builds a reader over the inclusive byte range [start, end],
// fetching chunks via chunkSrc. With config.Stream.MultiThreads == 0 chunks
// are fetched one at a time; otherwise that many fetchers run in parallel.
func newTGReader(
	ctx context.Context,
	start int64,
	end int64,
	config *config.TGConfig,
	chunkSrc ChunkSource,
) (*tgReader, error) {

	size := chunkSrc.ChunkSize(start, end)
	// Align the first fetch to a chunk boundary at or before start; the
	// misaligned head/tail bytes are trimmed via leftCut/rightCut.
	aligned := start - (start % size)

	reader := &tgReader{
		ctx:         ctx,
		limit:       end - start + 1,
		bufferChan:  make(chan *buffer, config.Stream.Buffers),
		concurrency: config.Stream.MultiThreads,
		leftCut:     start - aligned,
		rightCut:    (end % size) + 1,
		totalParts:  int((end - aligned + size) / size),
		offset:      aligned,
		chunkSize:   size,
		chunkSrc:    chunkSrc,
		timeout:     config.Stream.ChunkTimeout,
		done:        make(chan struct{}, 1),
		err:         make(chan error, 1),
	}

	if reader.concurrency == 0 {
		// Sequential mode numbers parts starting at 1.
		reader.currentPart = 1
		go reader.fillBufferSequentially()
	} else {
		go reader.fillBufferConcurrently()
	}

	return reader, nil
}
// Close signals the background fill goroutine to stop. It never fails.
//
// FIX: the original also closed r.err, but the fill goroutine may still be
// about to send on it (r.err <- err), which would panic with "send on closed
// channel"; the buffered error channel is simply left for the GC. Closing
// done is also guarded so a second Close after the first is a no-op.
// NOTE(review): still not safe for two fully concurrent Close calls.
func (r *tgReader) Close() error {
	select {
	case <-r.done:
		// already closed
	default:
		close(r.done)
	}
	return nil
}
// Read implements io.Reader, copying from the current chunk and pulling the
// next one from bufferChan when the current chunk is exhausted.
func (r *tgReader) Read(p []byte) (int, error) {

	r.mu.Lock()
	defer r.mu.Unlock()

	// Entire requested range already delivered.
	if r.limit <= 0 {
		return 0, io.EOF
	}

	if r.cur.isEmpty() {
		if r.cur != nil {
			r.cur = nil
		}
		select {
		case cur, ok := <-r.bufferChan:
			// A closed channel with bytes still owed means the fill
			// goroutine gave up (e.g. chunk timeout).
			if !ok && r.limit > 0 {
				return 0, ErrorStreamAbandoned
			}
			r.cur = cur
		case err := <-r.err:
			// Fatal fetch error reported by the fill goroutine.
			return 0, fmt.Errorf("error reading chunk: %w", err)
		case <-r.ctx.Done():
			return 0, r.ctx.Err()
		}
	}
	n := copy(p, r.cur.buffer())

	r.cur.increment(n)

	r.limit -= int64(n)

	// Returning n > 0 together with io.EOF is permitted by io.Reader.
	if r.limit <= 0 {
		return n, io.EOF
	}

	return n, nil
}
// fillBufferConcurrently fetches up to r.concurrency chunks per round in
// parallel, then emits them to bufferChan in order. It exits on error,
// timeout, Close, context cancellation, or when all parts are delivered;
// closing bufferChan on exit lets Read detect abandonment.
func (r *tgReader) fillBufferConcurrently() error {
	var mapMu sync.Mutex
	bufferMap := make(map[int]*buffer)

	defer func() {
		close(r.bufferChan)
		r.closed = true
		for i := range bufferMap {
			delete(bufferMap, i)
		}
	}()

	// cb builds the fetch closure for round-relative chunk index i.
	cb := func(ctx context.Context, i int) func() error {
		return func() error {
			chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
			if err != nil {
				return err
			}
			// Trim the protocol-aligned chunk down to the requested range.
			if r.totalParts == 1 {
				chunk = chunk[r.leftCut:r.rightCut]
			} else if r.currentPart+i+1 == 1 {
				chunk = chunk[r.leftCut:]
			} else if r.currentPart+i+1 == r.totalParts {
				chunk = chunk[:r.rightCut]
			}
			buf := &buffer{buf: chunk}
			mapMu.Lock()
			bufferMap[i] = buf
			mapMu.Unlock()
			return nil
		}
	}

	for {
		g := errgroup.Group{}
		g.SetLimit(r.concurrency)
		for i := range r.concurrency {
			if r.currentPart+i+1 <= r.totalParts {
				g.Go(cb(r.ctx, i))
			}
		}
		// done is buffered so the waiter goroutine never leaks a blocked
		// send if we leave via the timeout/done/ctx branches below.
		done := make(chan error, 1)
		go func() {
			done <- g.Wait()
		}()
		select {
		case err := <-done:
			if err != nil {
				r.err <- err
				return nil
			}
			for i := range r.concurrency {
				if r.currentPart+i+1 <= r.totalParts {
					// FIX: the original send could block forever when the
					// consumer stopped reading; abort on Close/cancel.
					select {
					case r.bufferChan <- bufferMap[i]:
					case <-r.done:
						return nil
					case <-r.ctx.Done():
						return r.ctx.Err()
					}
				}
			}
			r.currentPart += r.concurrency
			r.offset += r.chunkSize * int64(r.concurrency)
			for i := range bufferMap {
				delete(bufferMap, i)
			}
			if r.currentPart >= r.totalParts {
				return nil
			}
		case <-time.After(r.timeout):
			// Round took too long; Read will see ErrorStreamAbandoned.
			return nil
		case <-r.done:
			return nil
		case <-r.ctx.Done():
			return r.ctx.Err()
		}
	}
}
// fillBufferSequentially fetches chunks one at a time in order, emitting each
// to bufferChan, until all parts are delivered or the reader is stopped.
func (r *tgReader) fillBufferSequentially() error {
	defer close(r.bufferChan)

	fetchChunk := func(ctx context.Context) (*buffer, error) {
		// FIX: enforce the configured chunk timeout. The original selected on
		// time.After alongside a default case, so the timeout branch
		// effectively never fired and a fresh timer was allocated (and
		// leaked until expiry) every iteration.
		// NOTE(review): this only interrupts chunk sources that honor ctx.
		ctx, cancel := context.WithTimeout(ctx, r.timeout)
		defer cancel()
		chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
		if err != nil {
			return nil, err
		}
		// Trim the protocol-aligned chunk down to the requested range.
		if r.totalParts == 1 {
			chunk = chunk[r.leftCut:r.rightCut]
		} else if r.currentPart == 1 {
			chunk = chunk[r.leftCut:]
		} else if r.currentPart == r.totalParts {
			chunk = chunk[:r.rightCut]
		}
		return &buffer{buf: chunk}, nil
	}

	for {
		select {
		case <-r.done:
			return nil
		case <-r.ctx.Done():
			return r.ctx.Err()
		default:
			buf, err := fetchChunk(r.ctx)
			if err != nil {
				r.err <- err
				return nil
			}
			// FIX: abortable send — don't block forever if the consumer
			// stopped reading without closing the reader.
			select {
			case r.bufferChan <- buf:
			case <-r.done:
				return nil
			case <-r.ctx.Done():
				return r.ctx.Err()
			}
			r.currentPart++
			r.offset += r.chunkSize
			if r.currentPart > r.totalParts {
				return nil
			}
		}
	}
}
// buffer is a chunk of fetched bytes together with a read cursor.
type buffer struct {
	buf    []byte // fetched bytes
	offset int    // number of bytes already consumed
}

// isEmpty reports whether no unread bytes remain. A nil buffer is empty.
func (b *buffer) isEmpty() bool {
	return b == nil || b.offset >= len(b.buf)
}

// buffer returns the unread portion of the chunk.
func (b *buffer) buffer() []byte {
	return b.buf[b.offset:]
}

// increment advances the read cursor by n bytes.
func (b *buffer) increment(n int) {
	b.offset += n
}

View file

@ -0,0 +1,142 @@
package reader
import (
"context"
"crypto/rand"
"io"
"testing"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type testChunkSource struct {
buffer []byte
}
func (m *testChunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
return m.buffer[offset : offset+limit], nil
}
func (m *testChunkSource) ChunkSize(start, end int64) int64 {
return 1
}
type testChunkSourceTimeout struct {
buffer []byte
}
func (m *testChunkSourceTimeout) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
if offset == 8 {
time.Sleep(2 * time.Second)
}
return m.buffer[offset : offset+limit], nil
}
func (m *testChunkSourceTimeout) ChunkSize(start, end int64) int64 {
return 1
}
// TestSuite exercises tgReader against in-memory chunk sources.
type TestSuite struct {
	suite.Suite
	config *config.TGConfig
}

// SetupTest builds a config with multi-threaded streaming enabled
// (8 workers, 10 buffers, 1s chunk timeout).
// NOTE: the anonymous struct literal must mirror config.TGConfig's Stream
// field exactly, field-for-field, or this stops compiling when the config
// struct changes.
func (suite *TestSuite) SetupTest() {
	suite.config = &config.TGConfig{Stream: struct {
		MultiThreads int
		Buffers      int
		ChunkTimeout time.Duration
	}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
}
// TestFullRead streams the entire 100-byte payload and checks it round-trips.
func (suite *TestSuite) TestFullRead() {
	ctx := context.Background()
	start := int64(0)
	end := int64(99)
	data := make([]byte, 100)
	rand.Read(data)
	chunkSrc := &testChunkSource{buffer: data}

	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
	assert.NoError(suite.T(), err)

	testData, err := io.ReadAll(reader)
	// FIX: NoError instead of Equal(nil, err) — consistent with the other
	// tests and gives a readable failure message.
	assert.NoError(suite.T(), err)
	assert.Equal(suite.T(), data[start:end+1], testData)
}
// TestPartialRead streams a prefix of the payload and checks the bytes match.
func (suite *TestSuite) TestPartialRead() {
	ctx := context.Background()
	var (
		begin = int64(0)
		last  = int64(65)
	)
	payload := make([]byte, 100)
	rand.Read(payload)

	r, err := newTGReader(ctx, begin, last, suite.config, &testChunkSource{buffer: payload})
	assert.NoError(suite.T(), err)

	got, err := io.ReadAll(r)
	assert.NoError(suite.T(), err)
	assert.Equal(suite.T(), payload[begin:last+1], got)
}
// TestTimeout verifies that a stalled chunk fetch abandons the stream after
// the configured chunk timeout, with some data already delivered.
func (suite *TestSuite) TestTimeout() {
	ctx := context.Background()
	start := int64(0)
	end := int64(65)
	data := make([]byte, 100)
	rand.Read(data)
	chunkSrc := &testChunkSourceTimeout{buffer: data}

	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
	assert.NoError(suite.T(), err)

	testData, err := io.ReadAll(reader)
	assert.Greater(suite.T(), len(testData), 0)
	// FIX: testify's Equal takes (expected, actual) — the original had them
	// reversed; ErrorIs is the idiomatic check for sentinel errors.
	assert.ErrorIs(suite.T(), err, ErrorStreamAbandoned)
}
// TestClose verifies that a fully-drained reader can be closed cleanly.
func (suite *TestSuite) TestClose() {
	ctx := context.Background()
	var (
		begin = int64(0)
		last  = int64(65)
	)
	payload := make([]byte, 100)
	rand.Read(payload)

	r, err := newTGReader(ctx, begin, last, suite.config, &testChunkSource{buffer: payload})
	assert.NoError(suite.T(), err)

	_, err = io.ReadAll(r)
	assert.NoError(suite.T(), err)
	assert.NoError(suite.T(), r.Close())
}
// TestCancellation verifies that cancelling the context before reading makes
// Read fail with context.Canceled and leaves no buffered chunks behind.
func (suite *TestSuite) TestCancellation() {
	ctx, cancel := context.WithCancel(context.Background())
	start := int64(0)
	end := int64(65)
	data := make([]byte, 100)
	rand.Read(data)
	chunkSrc := &testChunkSource{buffer: data}

	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
	assert.NoError(suite.T(), err)

	cancel()

	_, err = io.ReadAll(reader)
	// FIX: ErrorIs instead of Equal — Read can surface the context error
	// wrapped ("error reading chunk: %w") depending on which select branch
	// fires first, and the original also had Equal's (expected, actual)
	// arguments reversed.
	assert.ErrorIs(suite.T(), err, context.Canceled)
	assert.Equal(suite.T(), 0, len(reader.bufferChan))
}
// TestCancellationWithTimeout verifies that a context deadline shorter than
// the stalled chunk fetch surfaces context.DeadlineExceeded to the reader.
func (suite *TestSuite) TestCancellationWithTimeout() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	// FIX: release the context's timer; the original discarded cancel with
	// `_ = cancel`, which go vet's lostcancel check flags.
	defer cancel()
	start := int64(0)
	end := int64(65)
	data := make([]byte, 100)
	rand.Read(data)
	chunkSrc := &testChunkSourceTimeout{buffer: data}

	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
	assert.NoError(suite.T(), err)

	_, err = io.ReadAll(reader)
	// ErrorIs tolerates the wrapped form of the context error.
	assert.ErrorIs(suite.T(), err, context.DeadlineExceeded)
	assert.Equal(suite.T(), 0, len(reader.bufferChan))
}
// Test runs the tgReader test suite.
func Test(t *testing.T) {
	suite.Run(t, new(TestSuite))
}

View file

@ -1,144 +0,0 @@
package reader
import (
"context"
"fmt"
"io"
"github.com/gotd/td/tg"
)
type tgReader struct {
ctx context.Context
client *tg.Client
location *tg.InputDocumentFileLocation
start int64
end int64
next func() ([]byte, error)
buffer []byte
limit int64
chunkSize int64
i int64
}
func calculateChunkSize(start, end int64) int64 {
chunkSize := int64(1024 * 1024)
for chunkSize > 1024 && chunkSize > (end-start) {
chunkSize /= 2
}
return chunkSize
}
func newTGReader(
ctx context.Context,
client *tg.Client,
location *tg.InputDocumentFileLocation,
start int64,
end int64,
) (io.ReadCloser, error) {
r := &tgReader{
ctx: ctx,
location: location,
client: client,
start: start,
end: end,
chunkSize: calculateChunkSize(start, end),
limit: end - start + 1,
}
r.next = r.partStream()
return r, nil
}
func (r *tgReader) Read(p []byte) (n int, err error) {
if r.limit <= 0 {
return 0, io.EOF
}
if r.i >= int64(len(r.buffer)) {
r.buffer, err = r.next()
if err != nil {
return 0, err
}
if len(r.buffer) == 0 {
r.next = r.partStream()
r.buffer, err = r.next()
if err != nil {
return 0, err
}
}
r.i = 0
}
n = copy(p, r.buffer[r.i:])
r.i += int64(n)
r.limit -= int64(n)
return
}
func (*tgReader) Close() error {
return nil
}
func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: r.location,
Precise: true,
}
res, err := r.client.UploadGetFile(r.ctx, req)
if err != nil {
return nil, err
}
switch result := res.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func (r *tgReader) partStream() func() ([]byte, error) {
start := r.start
end := r.end
offset := start - (start % r.chunkSize)
leftCut := start - offset
rightCut := (end % r.chunkSize) + 1
totalParts := int((end - offset + r.chunkSize) / r.chunkSize)
currentPart := 1
return func() ([]byte, error) {
if currentPart > totalParts {
return make([]byte, 0), nil
}
res, err := r.chunk(offset, r.chunkSize)
if err != nil {
return nil, err
}
if len(res) == 0 {
return res, nil
} else if totalParts == 1 {
res = res[leftCut:rightCut]
} else if currentPart == 1 {
res = res[leftCut:]
} else if currentPart == totalParts {
res = res[:rightCut]
}
currentPart++
offset += r.chunkSize
return res, nil
}
}

239
internal/tgc/helpers.go Normal file
View file

@ -0,0 +1,239 @@
package tgc
import (
"bytes"
"context"
"errors"
"fmt"
"math"
"runtime"
"sync"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"golang.org/x/sync/errgroup"
)
var (
ErrInValidChannelID = errors.New("invalid channel id")
ErrInvalidChannelMessages = errors.New("invalid channel messages")
)
// GetChannelById resolves channelId to an input channel usable in subsequent
// API calls. It returns ErrInValidChannelID when the id does not resolve to
// an accessible channel.
func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*tg.InputChannel, error) {
	inputChannel := &tg.InputChannel{
		ChannelID: channelId,
	}
	channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
	if err != nil {
		return nil, err
	}

	chats := channels.GetChats()
	if len(chats) == 0 {
		return nil, ErrInValidChannelID
	}
	// FIX: checked type assertion — the result may be another ChatClass
	// implementation (e.g. a forbidden channel), and the original unchecked
	// assertion would panic in that case.
	channel, ok := chats[0].(*tg.Channel)
	if !ok {
		return nil, ErrInValidChannelID
	}
	return channel.AsInput(), nil
}
// DeleteMessages removes the given message ids from a channel, deleting in
// batches of 100 with bounded parallelism. The first batch error, if any,
// is returned.
func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int64, ids []int) error {
	return RunWithAuth(ctx, client, "", func(ctx context.Context) error {
		channel, err := GetChannelById(ctx, client.API(), channelId)
		if err != nil {
			return err
		}
		batchSize := 100
		batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
		// Use the group-derived context so remaining batches are cancelled
		// as soon as one fails.
		g, gCtx := errgroup.WithContext(ctx)
		g.SetLimit(runtime.NumCPU())
		for i := 0; i < batchCount; i++ {
			start := i * batchSize
			end := min((i+1)*batchSize, len(ids))
			batchIds := ids[start:end]
			g.Go(func() error {
				messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
				// FIX: use a goroutine-local err — the original assigned to
				// the shared outer err from every goroutine, a data race
				// that could also clobber another goroutine's result.
				_, err := client.API().ChannelsDeleteMessages(gCtx, &messageDeleteRequest)
				return err
			})
		}
		return g.Wait()
	})
}
// getTGMessagesBatch fetches one batch of message ids from a channel.
func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {
	msgIds := make([]tg.InputMessageClass, 0, len(ids))
	for _, id := range ids {
		msgIds = append(msgIds, &tg.InputMessageID{ID: id})
	}
	return client.ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
		Channel: channel,
		ID:      msgIds,
	})
}
// GetMessages fetches the given message ids from a channel in batches of 200,
// in parallel, and returns all messages concatenated in batch order.
func GetMessages(ctx context.Context, client *tg.Client, ids []int, channelId int64) ([]tg.MessageClass, error) {

	channel, err := GetChannelById(ctx, client, channelId)

	if err != nil {
		return nil, err
	}

	batchSize := 200

	batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))

	g, _ := errgroup.WithContext(ctx)
	g.SetLimit(runtime.NumCPU())

	// Results land in a map keyed by batch index so they can be stitched
	// back together in order regardless of goroutine completion order.
	messageMap := make(map[int]*tg.MessagesChannelMessages)

	var mapMu sync.Mutex

	for i := range batchCount {
		g.Go(func() error {
			splitIds := ids[i*batchSize : min((i+1)*batchSize, len(ids))]
			res, err := getTGMessagesBatch(ctx, client, channel, splitIds)
			if err != nil {
				return err
			}
			messages, ok := res.(*tg.MessagesChannelMessages)
			if !ok {
				return ErrInvalidChannelMessages
			}
			mapMu.Lock()
			messageMap[i] = messages
			mapMu.Unlock()
			return nil
		})
	}

	if err = g.Wait(); err != nil {
		return nil, err
	}

	allMessages := []tg.MessageClass{}

	// Reassemble in batch order.
	for i := range batchCount {
		allMessages = append(allMessages, messageMap[i].Messages...)
	}
	return allMessages, nil
}
// GetChunk downloads a single chunk of a file, reading limit bytes starting
// at offset from the given location.
func GetChunk(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
	res, err := client.UploadGetFile(ctx, &tg.UploadGetFileRequest{
		Offset:   offset,
		Limit:    int(limit),
		Location: location,
		Precise:  true,
	})
	if err != nil {
		return nil, err
	}
	file, ok := res.(*tg.UploadFile)
	if !ok {
		return nil, fmt.Errorf("unexpected type %T", res)
	}
	return file.Bytes, nil
}
// GetMediaContent downloads an entire media file into memory, reading
// sequential 1 MiB chunks until Telegram returns an empty chunk.
func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
	const chunkSize = int64(1024 * 1024)
	out := &bytes.Buffer{}
	for offset := int64(0); ; offset += chunkSize {
		chunk, err := GetChunk(ctx, client, location, offset, chunkSize)
		if err != nil {
			// Return the partially filled buffer alongside the error,
			// matching the original contract.
			return out, err
		}
		if len(chunk) == 0 {
			return out, nil
		}
		out.Write(chunk)
	}
}
// GetBotInfo authenticates with the given bot token and returns the bot's
// identity (ID, username) together with the token.
func GetBotInfo(ctx context.Context, client *telegram.Client, token string) (*types.BotInfo, error) {
	var user *tg.User
	err := RunWithAuth(ctx, client, token, func(ctx context.Context) error {
		var err error
		// Propagate the error: the original discarded it, leaving user nil
		// and panicking on user.ID below whenever Self failed.
		user, err = client.Self(ctx)
		return err
	})
	if err != nil {
		return nil, err
	}
	return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
}
// GetLocation returns the Telegram document file location for one part of a
// file, caching the result for 3600 seconds under a per-user/file/part key.
//
// Fixes: the original called cache.Get with a nil named-return pointer, so
// the cached value could never be decoded into it and every call refetched
// from the API; the original also used three unchecked type assertions that
// panicked on deleted messages or non-document media.
func GetLocation(ctx context.Context, client *Client, fileId string, channelId int64, partId int64) (*tg.InputDocumentFileLocation, error) {
	cache := cache.FromContext(ctx)
	key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId)
	location := &tg.InputDocumentFileLocation{}
	if err := cache.Get(key, location); err == nil {
		return location, nil
	}
	channel, err := GetChannelById(ctx, client.Tg.API(), channelId)
	if err != nil {
		return nil, err
	}
	res, err := client.Tg.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
		Channel: channel,
		ID:      []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}},
	})
	if err != nil {
		return nil, err
	}
	messages, ok := res.(*tg.MessagesChannelMessages)
	if !ok || len(messages.Messages) == 0 {
		return nil, fmt.Errorf("unexpected messages response %T", res)
	}
	item, ok := messages.Messages[0].(*tg.Message)
	if !ok {
		return nil, fmt.Errorf("message %d not found", partId)
	}
	media, ok := item.Media.(*tg.MessageMediaDocument)
	if !ok {
		return nil, fmt.Errorf("message %d has no document media", partId)
	}
	document, ok := media.Document.(*tg.Document)
	if !ok {
		return nil, fmt.Errorf("message %d document unavailable", partId)
	}
	location = document.AsInputDocumentFileLocation()
	cache.Set(key, location, 3600)
	return location, nil
}
// CalculateChunkSize picks a download chunk size for the byte range
// [start, end]: starting from 1 MiB, it halves the size until the chunk
// fits within the range, never dropping below 1 KiB.
func CalculateChunkSize(start, end int64) int64 {
	const (
		maxChunk = int64(1 << 20) // 1 MiB upper bound
		minChunk = int64(1 << 10) // 1 KiB lower bound
	)
	span := end - start
	size := maxChunk
	for size > minChunk && size > span {
		size >>= 1
	}
	return size
}

View file

@ -55,8 +55,8 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa
LangCode: config.LangCode,
},
SessionStorage: storage,
RetryInterval: 5 * time.Second,
MaxRetries: 5,
RetryInterval: 2 * time.Second,
MaxRetries: 10,
DialTimeout: 10 * time.Second,
Middlewares: middlewares,
UpdateHandler: handler,

View file

@ -2,11 +2,11 @@ package tgc
import (
"context"
"strings"
"sync"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/pool"
"github.com/gotd/td/telegram"
)
@ -42,9 +42,9 @@ func NewUploadWorker() *UploadWorker {
type Client struct {
Tg *telegram.Client
Pool pool.Pool
Stop StopFunc
Status string
UserId string
}
type StreamWorker struct {
@ -69,10 +69,7 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
for _, token := range bots {
middlewares := Middlewares(w.cnf, 5)
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
c := &Client{Tg: client, Status: "idle"}
if w.cnf.Stream.UsePooling {
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
}
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
w.clients[channelId] = append(w.clients[channelId], c)
}
w.currIdx[channelId] = 0
@ -108,9 +105,6 @@ func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error)
middlewares := Middlewares(w.cnf, 5)
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
c := &Client{Tg: client, Status: "idle"}
if w.cnf.Stream.UsePooling {
c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...)
}
w.clients[userId] = append(w.clients[userId], c)
}
nextClient := w.clients[userId][0]

View file

@ -3,11 +3,11 @@ package controller
import (
"net/http"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/logging"
"github.com/divyam234/teldrive/pkg/httputil"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/services"
"github.com/gin-gonic/gin"
)
@ -23,7 +23,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
return
}
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
if err != nil {
@ -36,7 +36,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
func (fc *Controller) UpdateFile(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
var fileUpdate schemas.FileUpdate
@ -65,7 +65,7 @@ func (fc *Controller) GetFileByID(c *gin.Context) {
func (fc *Controller) ListFiles(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
fquery := schemas.FileQuery{
PerPage: 500,
@ -90,7 +90,7 @@ func (fc *Controller) ListFiles(c *gin.Context) {
func (fc *Controller) MakeDirectory(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
var payload schemas.MkDir
if err := c.ShouldBindJSON(&payload); err != nil {
@ -118,7 +118,7 @@ func (fc *Controller) CopyFile(c *gin.Context) {
func (fc *Controller) MoveFiles(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
var payload schemas.FileOperation
if err := c.ShouldBindJSON(&payload); err != nil {
@ -136,7 +136,7 @@ func (fc *Controller) MoveFiles(c *gin.Context) {
func (fc *Controller) DeleteFiles(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
var payload schemas.DeleteOperation
if err := c.ShouldBindJSON(&payload); err != nil {
@ -164,7 +164,7 @@ func (fc *Controller) DeleteFileParts(c *gin.Context) {
}
func (fc *Controller) MoveDirectory(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
var payload schemas.DirMove
if err := c.ShouldBindJSON(&payload); err != nil {
@ -181,7 +181,7 @@ func (fc *Controller) MoveDirectory(c *gin.Context) {
}
func (fc *Controller) GetCategoryStats(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
res, err := fc.FileService.GetCategoryStats(userId)
if err != nil {
@ -193,5 +193,9 @@ func (fc *Controller) GetCategoryStats(c *gin.Context) {
}
func (fc *Controller) GetFileStream(c *gin.Context) {
fc.FileService.GetFileStream(c)
fc.FileService.GetFileStream(c, false)
}
func (fc *Controller) GetFileDownload(c *gin.Context) {
fc.FileService.GetFileStream(c, true)
}

View file

@ -4,8 +4,8 @@ import (
"net/http"
"strconv"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/pkg/httputil"
"github.com/divyam234/teldrive/pkg/services"
"github.com/gin-gonic/gin"
)
@ -40,7 +40,7 @@ func (uc *Controller) UploadFile(c *gin.Context) {
}
func (uc *Controller) UploadStats(c *gin.Context) {
userId, _ := services.GetUserAuth(c)
userId, _ := auth.GetUser(c)
days := 7

View file

@ -6,9 +6,9 @@ import (
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/logging"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/services"
"github.com/go-co-op/gocron"
"github.com/jackc/pgx/v5/pgtype"
"go.uber.org/zap"
@ -87,7 +87,8 @@ func (c *CronService) CleanFiles(ctx context.Context) {
}
}
err := services.DeleteTGMessages(ctx, &c.cnf.TG, row.Session, row.ChannelId, row.UserId, ids)
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, row.Session)
err := tgc.DeleteMessages(ctx, client, row.ChannelId, ids)
if err != nil {
c.logger.Errorw("failed to delete messages", err)
@ -122,7 +123,9 @@ func (c *CronService) CleanUploads(ctx context.Context) {
for _, result := range upResults {
if result.Session != "" && len(result.Parts) > 0 {
err := services.DeleteTGMessages(ctx, &c.cnf.TG, result.Session, result.ChannelId, result.UserId, result.Parts)
client, _ := tgc.AuthClient(ctx, &c.cnf.TG, result.Session)
err := tgc.DeleteMessages(ctx, client, result.ChannelId, result.Parts)
if err != nil {
c.logger.Errorw("failed to delete messages", err)
return

View file

@ -1,214 +1,21 @@
package services
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"math"
"sort"
"strconv"
"sync"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/pkg/errors"
"github.com/thoas/go-funk"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
)
type buffer struct {
Buf []byte
}
func (b *buffer) long() (int64, error) {
v, err := b.uint64()
if err != nil {
return 0, err
}
return int64(v), nil
}
func (b *buffer) uint64() (uint64, error) {
const size = 8
if len(b.Buf) < size {
return 0, io.ErrUnexpectedEOF
}
v := binary.LittleEndian.Uint64(b.Buf)
b.Buf = b.Buf[size:]
return v, nil
}
func randInt64() (int64, error) {
var buf [8]byte
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
return 0, err
}
b := &buffer{Buf: buf[:]}
return b.long()
}
type batchResult struct {
Index int
Messages *tg.MessagesChannelMessages
}
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: location,
}
r, err := tgClient.API().UploadGetFile(ctx, req)
if err != nil {
return nil, err
}
switch result := r.(type) {
case *tg.UploadFile:
return result.Bytes, nil
default:
return nil, fmt.Errorf("unexpected type %T", r)
}
}
func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) {
offset := int64(0)
limit := int64(1024 * 1024)
buff := &bytes.Buffer{}
for {
r, err := getChunk(ctx, tgClient, location, offset, limit)
if err != nil {
return buff, err
}
if len(r) == 0 {
break
}
buff.Write(r)
offset += int64(limit)
}
return buff, nil
}
func GetUserAuth(c *gin.Context) (int64, string) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
var user *tg.User
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)
return nil
})
if err != nil {
return nil, err
}
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
}
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, index int,
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
return &tg.InputMessageID{ID: int(part.ID)}
}).([]tg.InputMessageClass)
messageRequest := tg.ChannelsGetMessagesRequest{
Channel: channel,
ID: ids,
}
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
errors <- err
return
}
messages, ok := res.(*tg.MessagesChannelMessages)
if !ok {
errors <- fmt.Errorf("unexpected response type: %T", res)
return
}
results <- batchResult{Index: index, Messages: messages}
}
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
channel, err := GetChannelById(ctx, client, channelId, userID)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
batchSize := 200
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
results := make(chan batchResult, batchCount)
errors := make(chan error, batchCount)
for i := range batchCount {
wg.Add(1)
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
go getTGMessagesBatch(ctx, client, channel, splitParts, i, results, errors, &wg)
}
wg.Wait()
close(results)
close(errors)
for err := range errors {
if err != nil {
return nil, err
}
}
batchResults := []batchResult{}
for result := range results {
batchResults = append(batchResults, result)
}
sort.Slice(batchResults, func(i, j int) bool {
return batchResults[i].Index < batchResults[j].Index
})
allMessages := []tg.MessageClass{}
for _, result := range batchResults {
allMessages = append(allMessages, result.Messages.GetMessages()...)
}
return allMessages, nil
}
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
cache := cache.FromContext(ctx)
parts := []types.Part{}
@ -220,7 +27,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
return parts, nil
}
messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID)
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID)
if err != nil {
return nil, err
@ -230,12 +41,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
part := types.Part{
Location: location,
Size: document.Size,
Salt: file.Parts[i].Salt,
ID: file.Parts[i].ID,
Size: document.Size,
Salt: file.Parts[i].Salt,
}
if file.Encrypted {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
@ -246,27 +56,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
return parts, nil
}
func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) {
channel := &tg.InputChannel{}
inputChannel := &tg.InputChannel{
ChannelID: channelId,
}
channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
if err != nil {
return nil, err
}
if len(channels.GetChats()) == 0 {
return nil, errors.New("no channels found")
}
channel = channels.GetChats()[0].(*tg.Channel).AsInput()
return channel, nil
}
func GetDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) {
cache := cache.FromContext(ctx)
var channelId int64
key := fmt.Sprintf("users:channel:%d", userID)
@ -332,37 +122,3 @@ func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Ses
return &session, nil
}
func DeleteTGMessages(ctx context.Context, cnf *config.TGConfig, session string, channelId, userId int64, ids []int) error {
client, _ := tgc.AuthClient(ctx, cnf, session)
err := tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
if err != nil {
return err
}
batchSize := 100
batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize)))
g, _ := errgroup.WithContext(ctx)
g.SetLimit(8)
for i := 0; i < batchCount; i++ {
start := i * batchSize
end := min((i+1)*batchSize, len(ids))
batchIds := ids[start:end]
g.Go(func() error {
messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds}
_, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest)
return err
})
}
return g.Wait()
})
return err
}

View file

@ -2,7 +2,9 @@ package services
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"mime"
@ -12,6 +14,7 @@ import (
"time"
"github.com/WinterYukky/gorm-extra-clause-plugin/exclause"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/category"
"github.com/divyam234/teldrive/internal/config"
@ -35,6 +38,36 @@ import (
"gorm.io/gorm/clause"
)
type buffer struct {
Buf []byte
}
func (b *buffer) long() (int64, error) {
v, err := b.uint64()
if err != nil {
return 0, err
}
return int64(v), nil
}
func (b *buffer) uint64() (uint64, error) {
const size = 8
if len(b.Buf) < size {
return 0, io.ErrUnexpectedEOF
}
v := binary.LittleEndian.Uint64(b.Buf)
b.Buf = b.Buf[size:]
return v, nil
}
func randInt64() (int64, error) {
var buf [8]byte
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
return 0, err
}
b := &buffer{Buf: buf[:]}
return b.long()
}
type FileService struct {
db *gorm.DB
cnf *config.TGConfig
@ -75,7 +108,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.
channelId := fileIn.ChannelID
if fileIn.ChannelID == 0 {
var err error
channelId, err = GetDefaultChannel(c, fs.db, userId)
channelId, err = getDefaultChannel(c, fs.db, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
}
@ -332,7 +365,9 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
return nil, &types.AppError{Error: err}
}
userId, session := GetUserAuth(c)
_, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session)
ids := []int{}
@ -340,7 +375,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess
ids = append(ids, int(part.ID))
}
err := DeleteTGMessages(c, fs.cnf, session, *file.ChannelID, userId, ids)
err := tgc.DeleteMessages(c, client, *file.ChannelID, ids)
if err != nil {
return nil, &types.AppError{Error: err}
@ -380,7 +415,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
}
userId, session := GetUserAuth(c)
userId, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, fs.cnf, session)
@ -394,20 +429,24 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
newIds := []schemas.Part{}
channelId, err := GetDefaultChannel(c, fs.db, userId)
channelId, err := getDefaultChannel(c, fs.db, userId)
if err != nil {
return nil, &types.AppError{Error: err}
}
err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
user := strconv.FormatInt(userId, 10)
messages, err := getTGMessages(c, client, file.Parts, file.ChannelID, user)
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
messages, err := tgc.GetMessages(c, client.API(), ids, file.ChannelID)
if err != nil {
return err
}
channel, err := GetChannelById(ctx, client, channelId, user)
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
if err != nil {
return err
@ -482,7 +521,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
return mapper.ToFileOut(dbFile), nil
}
func (fs *FileService) GetFileStream(c *gin.Context) {
func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
w := c.Writer
@ -585,7 +624,7 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
disposition := "inline"
if c.Query("d") == "1" {
if download {
disposition = "attachment"
}
@ -601,14 +640,15 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
}
var (
channelUser string
lr io.ReadCloser
channelUser string
lr io.ReadCloser
client *tgc.Client
multiThreads int
)
var client *tgc.Client
multiThreads = fs.cnf.Stream.MultiThreads
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil {
logger.Error("file stream", zap.Error(err))
@ -616,47 +656,36 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
return
}
channelUser = strconv.FormatInt(session.UserId, 10)
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "user", channelUser, "start", start,
"end", end, "fileSize", file.Size)
multiThreads = 0
} else {
var index int
limit := min(len(tokens), fs.cnf.BgBotsLimit)
fs.worker.Set(tokens[:limit], file.ChannelID)
client, index, err = fs.worker.Next(file.ChannelID)
client, _, err = fs.worker.Next(file.ChannelID)
if err != nil {
logger.Error("file stream", zap.Error(err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
channelUser = strings.Split(tokens[index], ":")[0]
logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "botNo", index, "start", start,
"end", end, "fileSize", file.Size)
}
if r.Method != "HEAD" {
parts, err := getParts(c, client.Tg, file, channelUser)
parts, err := getParts(c, client.Tg.API(), file, channelUser)
if err != nil {
logger.Error("file stream", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
tgClient := client.Tg.API()
if fs.cnf.Stream.UsePooling {
tgClient = client.Pool.Default(c)
if download {
multiThreads = 0
}
if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, tgClient, parts, start, end, fs.cnf)
lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
} else {
lr, err = reader.NewLinearReader(c, tgClient, parts, start, end, fs.cnf)
lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker)
}
if err != nil {
@ -669,7 +698,10 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
return
}
io.CopyN(w, lr, contentLength)
_, err = io.CopyN(w, lr, contentLength)
if err != nil {
lr.Close()
}
}
}
func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB {

View file

@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/logging"
@ -116,7 +117,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
Code: http.StatusBadRequest}
}
userId, session := GetUserAuth(c)
userId, session := auth.GetUser(c)
uploadId := c.Param("id")
@ -127,7 +128,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
defer fileStream.Close()
if uploadQuery.ChannelID == 0 {
channelId, err = GetDefaultChannel(c, us.db, userId)
channelId, err = getDefaultChannel(c, us.db, userId)
if err != nil {
return nil, &types.AppError{Error: err}
}
@ -174,7 +175,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
channel, err := GetChannelById(ctx, client, channelId, channelUser)
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
if err != nil {
return err

View file

@ -10,6 +10,7 @@ import (
"sync"
"time"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv"
@ -41,7 +42,7 @@ func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV) *UserService {
return &UserService{db: db, cnf: cnf, kv: kv}
}
func (us *UserService) GetProfilePhoto(c *gin.Context) {
_, session := GetUserAuth(c)
_, session := auth.GetUser(c)
client, err := tgc.AuthClient(c, &us.cnf.TG, session)
@ -64,7 +65,7 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
return errors.New("profile not found")
}
location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID}
buff, err := iterContent(c, client, location)
buff, err := tgc.GetMediaContent(c, client.API(), location)
if err != nil {
return err
}
@ -83,13 +84,13 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) {
}
func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.AppError) {
userID, _ := GetUserAuth(c)
userID, _ := auth.GetUser(c)
var (
channelId int64
err error
)
channelId, _ = GetDefaultChannel(c, us.db, userID)
channelId, _ = getDefaultChannel(c, us.db, userID)
tokens, err := getBotsToken(c, us.db, userID, channelId)
@ -103,7 +104,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
cache := cache.FromContext(c)
userId, _ := GetUserAuth(c)
userId, _ := auth.GetUser(c)
var payload schemas.Channel
@ -130,7 +131,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A
}
func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *types.AppError) {
userId, userSession := GetUserAuth(c)
userId, userSession := auth.GetUser(c)
client, _ := tgc.AuthClient(c, &us.cnf.TG, userSession)
@ -185,7 +186,7 @@ func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *type
func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.AppError) {
userId, _ := GetUserAuth(c)
userId, _ := auth.GetUser(c)
session := &models.Session{}
@ -209,7 +210,7 @@ func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.A
}
func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) {
_, session := GetUserAuth(c)
_, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
channels := make(map[int64]*schemas.Channel)
@ -236,7 +237,7 @@ func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppErro
}
func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) {
userId, session := GetUserAuth(c)
userId, session := auth.GetUser(c)
client, _ := tgc.AuthClient(c, &us.cnf.TG, session)
var botsTokens []string
@ -249,7 +250,7 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro
return &schemas.Message{Message: "no bots to add"}, nil
}
channelId, err := GetDefaultChannel(c, us.db, userId)
channelId, err := getDefaultChannel(c, us.db, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
@ -263,9 +264,9 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE
cache := cache.FromContext(c)
userID, _ := GetUserAuth(c)
userID, _ := auth.GetUser(c)
channelId, err := GetDefaultChannel(c, us.db, userID)
channelId, err := getDefaultChannel(c, us.db, userID)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
@ -294,7 +295,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10))
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
if err != nil {
logger.Error("error", zap.Error(err))
@ -309,7 +310,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI
waitChan <- struct{}{}
wg.Add(1)
go func(t string) {
info, err := getBotInfo(c, us.kv, &us.cnf.TG, t)
info, err := tgc.GetBotInfo(c, client, t)
if err != nil {
return
}

View file

@ -3,7 +3,6 @@ package types
import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/gotd/td/session"
"github.com/gotd/td/tg"
)
type AppError struct {
@ -12,10 +11,10 @@ type AppError struct {
}
type Part struct {
Location *tg.InputDocumentFileLocation
DecryptedSize int64
Size int64
Salt string
ID int64
}
type JWTClaims struct {