diff --git a/.github/workflows/build-dev.yml b/.github/workflows/build-dev.yml new file mode 100644 index 0000000..9476052 --- /dev/null +++ b/.github/workflows/build-dev.yml @@ -0,0 +1,36 @@ +name: Build Dev + +on: + push: + branches: + - dev + +permissions: write-all + +jobs: + goreleaser: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set Vars + run: | + echo "IMAGE=ghcr.io/${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV + + - name: Build Image + uses: docker/build-push-action@v6 + with: + context: . + push: true + tags: ${{ env.IMAGE }}:dev \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4ccc583 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ + +FROM golang:alpine AS builder + +RUN apk add --no-cache git unzip curl make bash + +WORKDIR /app + +COPY go.mod go.sum ./ + +RUN go mod download + +COPY . . + +RUN make build + +FROM scratch + +WORKDIR / + +COPY --from=builder /app/bin/teldrive /teldrive + +EXPOSE 8080 + +ENTRYPOINT ["/teldrive","run","--tg-session-file","/session.db"] \ No newline at end of file diff --git a/Makefile b/Makefile index 2f57e19..12c06cd 100644 --- a/Makefile +++ b/Makefile @@ -11,19 +11,16 @@ FRONTEND_ASSET := https://github.com/divyam234/teldrive-ui/releases/download/v1/ GIT_TAG := $(shell git describe --tags --abbrev=0) GIT_COMMIT := $(shell git rev-parse --short HEAD) GIT_LINK := $(shell git remote get-url origin) -ENV_FILE := $(FRONTEND_DIR)/.env MODULE_PATH := $(shell go list -m) -BUILD_DATE := $(shell $(BUILD_DATE)) - GOOS ?= $(shell go env GOOS) GOARCH ?= $(shell go env GOARCH) +VERSION:= $(GIT_TAG) +BINARY_EXTENSION := .PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version all: build - - frontend: @echo "Extract UI" ifeq ($(OS),Windows_NT) @@ -40,10 +37,13 @@ else rm -rf teldrive-ui.zip endif +ifeq ($(OS),Windows_NT) + BINARY_EXTENSION := .exe +endif backend: @echo "Building backend for $(GOOS)/$(GOARCH)..." - go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(GIT_TAG) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION) + go build -trimpath -ldflags "-s -w -X $(MODULE_PATH)/internal/config.Version=$(VERSION) -extldflags=-static" -o $(BUILD_DIR)/$(APP_NAME)$(BINARY_EXTENSION) build: frontend backend @echo "Building complete." diff --git a/api/router.go b/api/router.go index 9fd2342..eabc234 100644 --- a/api/router.go +++ b/api/router.go @@ -28,6 +28,8 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi files.PATCH(":fileID", authmiddleware, c.UpdateFile) files.HEAD(":fileID/stream/:fileName", c.GetFileStream) files.GET(":fileID/stream/:fileName", c.GetFileStream) + files.HEAD(":fileID/download/:fileName", c.GetFileDownload) + files.GET(":fileID/download/:fileName", c.GetFileDownload) files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts) files.GET("/category/stats", authmiddleware, c.GetCategoryStats) files.POST("/move", authmiddleware, c.MoveFiles) diff --git a/cmd/run.go b/cmd/run.go index 0db6dfe..76a7385 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -89,12 +89,11 @@ func NewRun() *cobra.Command { runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads") runCmd.Flags().IntVar(&config.TG.Uploads.MaxRetries, "tg-uploads-max-retries", 10, "Uploads Retries") runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size") - runCmd.Flags().BoolVar(&config.TG.Stream.BufferReader, "tg-stream-buffer-reader", false, "Async Buffered reader for fast streaming") - runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 16, "No of Stream buffers") - runCmd.Flags().BoolVar(&config.TG.Stream.UseMmap, "tg-stream-use-mmap", false, "Use mmap for stream buffers") - runCmd.Flags().BoolVar(&config.TG.Stream.UsePooling, "tg-stream-use-pooling", false, "Use session pooling for stream workers") duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout") duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration") + runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") + runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") + duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout") runCmd.MarkFlagRequired("tg-app-id") runCmd.MarkFlagRequired("tg-app-hash") runCmd.MarkFlagRequired("db-data-source") @@ -162,11 +161,11 @@ func initViperConfig(cmd *cobra.Command) error { viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) viper.AutomaticEnv() viper.ReadInConfig() - bindFlagsRecursive(cmd.Flags(), "", reflect.ValueOf(config.Config{})) + bindFlags(cmd.Flags(), "", reflect.ValueOf(config.Config{})) return nil } -func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) { +func bindFlags(flags *pflag.FlagSet, prefix string, v reflect.Value) { t := v.Type() if t.Kind() == reflect.Ptr { t = t.Elem() @@ -175,7 +174,7 @@ func bindFlagsRecursive(flags *pflag.FlagSet, prefix string, v reflect.Value) { field := t.Field(i) switch field.Type.Kind() { case reflect.Struct: - bindFlagsRecursive(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i)) + bindFlags(flags, fmt.Sprintf("%s.%s", prefix, strings.ToLower(field.Name)), v.Field(i)) default: newPrefix := prefix[1:] newName := modifyFlag(field.Name) diff --git a/config.sample.toml b/config.sample.toml index d496eba..f449450 100644 --- a/config.sample.toml +++ b/config.sample.toml @@ -43,8 +43,6 @@ retention = "7d" threads = 8 [tg.stream] - buffer-reader = false - buffers = 6 - use-mmap = false - use-pooling= false + multi-threads = 0 + buffers = 16 diff --git a/go.mod b/go.mod index 0b3a927..0b92ec9 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/magiconair/properties v1.8.7 github.com/mitchellh/go-homedir v1.1.0 github.com/pkg/errors v0.9.1 - github.com/rclone/rclone v1.67.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 @@ -97,7 +96,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pressly/goose/v3 v3.20.0 + github.com/pressly/goose/v3 v3.21.1 github.com/segmentio/asm v1.2.0 // indirect github.com/stretchr/testify v1.9.0 github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 82093fc..eb5863d 100644 --- a/go.sum +++ b/go.sum @@ -145,8 +145,8 @@ github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbW github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= -github.com/microsoft/go-mssqldb v1.7.0 h1:sgMPW0HA6Ihd37Yx0MzHyKD726C2kY/8KJsQtXHNaAs= -github.com/microsoft/go-mssqldb v1.7.0/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= +github.com/microsoft/go-mssqldb v1.7.1 h1:KU/g8aWeM3Hx7IMOFpiwYiUkU+9zeISb4+tx3ScVfsM= +github.com/microsoft/go-mssqldb v1.7.1/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -166,10 +166,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pressly/goose/v3 v3.20.0 h1:uPJdOxF/Ipj7ABVNOAMJXSxwFXZGwMGHNqjC8e61VA0= -github.com/pressly/goose/v3 v3.20.0/go.mod h1:BRfF2GcG4FTG12QfdBVy3q1yveaf4ckL9vWwEcIO3lA= -github.com/rclone/rclone v1.67.0 h1:yLRNgHEG2vQ60HCuzFqd0hYwKCRuWuvPUhvhMJ2jI5E= -github.com/rclone/rclone v1.67.0/go.mod h1:Cb3Ar47M/SvwfhAjZTbVXdtrP/JLtPFCq2tkdtBVC6w= +github.com/pressly/goose/v3 v3.21.1 h1:5SSAKKWej8LVVzNLuT6KIvP1eFDuPvxa+B6H0w78buQ= +github.com/pressly/goose/v3 v3.21.1/go.mod h1:sqthmzV8PitchEkjecFJII//l43dLOCzfWh8pHEe+vE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= diff --git a/internal/auth/jwe.go b/internal/auth/jwe.go index 60dd05c..c61fea7 100644 --- a/internal/auth/jwe.go +++ b/internal/auth/jwe.go @@ -2,8 +2,10 @@ package auth import ( "encoding/json" + "strconv" "github.com/divyam234/teldrive/pkg/types" + "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v3" ) @@ -59,3 +61,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) { return jwtToken, nil } + +func GetUser(c *gin.Context) (int64, string) { + val, _ := c.Get("jwtUser") + jwtUser := val.(*types.JWTClaims) + userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64) + return userId, jwtUser.TgSession +} diff --git a/internal/config/config.go b/internal/config/config.go index 481a30e..5f27ed5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,10 +43,9 @@ type TGConfig struct { Retention time.Duration } Stream struct { - BufferReader bool + MultiThreads int Buffers int - UseMmap bool - UsePooling bool + ChunkTimeout time.Duration } } diff --git a/internal/reader/async-reader.go b/internal/reader/async-reader.go deleted file mode 100644 index 3e76fcd..0000000 --- a/internal/reader/async-reader.go +++ /dev/null @@ -1,206 +0,0 @@ -// taken from rclone async reader implmentation -package reader - -import ( - "context" - "errors" - "io" - "sync" - "time" - - "github.com/rclone/rclone/lib/pool" -) - -const ( - BufferSize = 1024 * 1024 - softStartInitial = 4 * 1024 - bufferCacheSize = 64 - bufferCacheFlushTime = 5 * time.Second -) - -var ErrorStreamAbandoned = errors.New("stream abandoned") - -type AsyncReader struct { - in io.ReadCloser - ready chan *buffer - token chan struct{} - exit chan struct{} - buffers int - err error - cur *buffer - exited chan struct{} - closed bool - mu sync.Mutex -} - -func NewAsyncReader(ctx context.Context, rd io.ReadCloser, buffers int) (*AsyncReader, error) { - if buffers <= 0 { - return nil, errors.New("number of buffers too small") - } - if rd == nil { - return nil, errors.New("nil reader supplied") - } - a := &AsyncReader{} - a.init(rd, buffers) - return a, nil -} - -func (a *AsyncReader) init(rd io.ReadCloser, buffers int) { - a.in = rd - a.ready = make(chan *buffer, buffers) - a.token = make(chan struct{}, buffers) - a.exit = make(chan struct{}) - a.exited = make(chan struct{}) - a.buffers = buffers - a.cur = nil - - for i := 0; i < buffers; i++ { - a.token <- struct{}{} - } - - go func() { - defer close(a.exited) - defer close(a.ready) - for { - select { - case <-a.token: - b := a.getBuffer() - err := b.read(a.in) - a.ready <- b - if err != nil { - return - } - case <-a.exit: - return - } - } - }() -} - -var bufferPool *pool.Pool -var bufferPoolOnce sync.Once - -func (a *AsyncReader) putBuffer(b *buffer) { - bufferPool.Put(b.buf) - b.buf = nil -} - -func (a *AsyncReader) getBuffer() *buffer { - bufferPoolOnce.Do(func() { - bufferPool = pool.New(bufferCacheFlushTime, BufferSize, bufferCacheSize, false) - }) - return &buffer{ - buf: bufferPool.Get(), - } -} - -func (a *AsyncReader) fill() (err error) { - if a.cur.isEmpty() { - if a.cur != nil { - a.putBuffer(a.cur) - a.token <- struct{}{} - a.cur = nil - } - b, ok := <-a.ready - if !ok { - if a.err == nil { - return ErrorStreamAbandoned - } - return a.err - } - a.cur = b - } - return nil -} - -func (a *AsyncReader) Read(p []byte) (n int, err error) { - a.mu.Lock() - defer a.mu.Unlock() - - err = a.fill() - if err != nil { - return 0, err - } - - n = copy(p, a.cur.buffer()) - a.cur.increment(n) - - if a.cur.isEmpty() { - a.err = a.cur.err - return n, a.err - } - return n, nil -} - -func (a *AsyncReader) StopBuffering() { - select { - case <-a.exit: - return - default: - } - close(a.exit) - <-a.exited -} - -func (a *AsyncReader) Abandon() { - a.StopBuffering() - a.mu.Lock() - defer a.mu.Unlock() - if a.cur != nil { - a.putBuffer(a.cur) - a.cur = nil - } - for b := range a.ready { - a.putBuffer(b) - } -} - -func (a *AsyncReader) Close() (err error) { - a.Abandon() - if a.closed { - return nil - } - a.closed = true - return a.in.Close() -} - -type buffer struct { - buf []byte - err error - offset int -} - -func (b *buffer) isEmpty() bool { - if b == nil { - return true - } - if len(b.buf)-b.offset <= 0 { - return true - } - return false -} - -func (b *buffer) readFill(r io.Reader, buf []byte) (n int, err error) { - var nn int - for n < len(buf) && err == nil { - nn, err = r.Read(buf[n:]) - n += nn - } - return n, err -} - -func (b *buffer) read(rd io.Reader) error { - var n int - n, b.err = b.readFill(rd, b.buf) - b.buf = b.buf[0:n] - b.offset = 0 - return b.err -} - -func (b *buffer) buffer() []byte { - return b.buf[b.offset:] -} - -func (b *buffer) increment(n int) { - b.offset += n -} diff --git a/internal/reader/decrypted-reader.go b/internal/reader/decrypted_reader.go similarity index 56% rename from internal/reader/decrypted-reader.go rename to internal/reader/decrypted_reader.go index 57f51a0..551bda1 100644 --- a/internal/reader/decrypted-reader.go +++ b/internal/reader/decrypted_reader.go @@ -6,36 +6,47 @@ import ( "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/crypt" + "github.com/divyam234/teldrive/internal/tgc" "github.com/divyam234/teldrive/pkg/types" - "github.com/gotd/td/tg" ) type decrpytedReader struct { - ctx context.Context - parts []types.Part - ranges []types.Range - pos int - client *tg.Client - reader io.ReadCloser - limit int64 - err error - config *config.TGConfig + ctx context.Context + parts []types.Part + ranges []types.Range + pos int + reader io.ReadCloser + limit int64 + config *config.TGConfig + channelId int64 + worker *tgc.StreamWorker + client *tgc.Client + fileId string + concurrency int } func NewDecryptedReader( ctx context.Context, - client *tg.Client, + fileId string, parts []types.Part, start, end int64, - config *config.TGConfig) (io.ReadCloser, error) { + channelId int64, + config *config.TGConfig, + concurrency int, + client *tgc.Client, + worker *tgc.StreamWorker) (io.ReadCloser, error) { r := &decrpytedReader{ - ctx: ctx, - parts: parts, - client: client, - limit: end - start + 1, - ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize), - config: config, + ctx: ctx, + parts: parts, + limit: end - start + 1, + ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize), + config: config, + client: client, + worker: worker, + channelId: channelId, + fileId: fileId, + concurrency: concurrency, } res, err := r.nextPart() @@ -51,30 +62,24 @@ func NewDecryptedReader( func (r *decrpytedReader) Read(p []byte) (n int, err error) { - if r.err != nil { - return 0, r.err - } - if r.limit <= 0 { return 0, io.EOF } n, err = r.reader.Read(p) - - if err == nil { - r.limit -= int64(n) - } - + r.limit -= int64(n) if err == io.EOF { if r.limit > 0 { err = nil + if r.reader != nil { + r.reader.Close() + } } r.pos++ if r.pos < len(r.ranges) { r.reader, err = r.nextPart() } } - r.err = err return } @@ -89,7 +94,6 @@ func (r *decrpytedReader) Close() (err error) { func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { - location := r.parts[r.ranges[r.pos].PartNo].Location start := r.ranges[r.pos].Start end := r.ranges[r.pos].End salt := r.parts[r.ranges[r.pos].PartNo].Salt @@ -99,21 +103,15 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { func(ctx context.Context, underlyingOffset, underlyingLimit int64) (io.ReadCloser, error) { - var end int64 if underlyingLimit >= 0 { end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1) } - rd, err := newTGReader(r.ctx, r.client, location, underlyingOffset, end) - if err != nil { - return nil, err - } - if r.config.Stream.BufferReader { - return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers) - - } - return rd, nil + chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, + fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, + client: r.client, concurrency: r.concurrency} + return newTGReader(r.ctx, start, end, r.config, chunkSrc) }, start, end-start+1) diff --git a/internal/reader/reader.go b/internal/reader/reader.go index 754ce8a..8db2b76 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -5,8 +5,8 @@ import ( "io" "github.com/divyam234/teldrive/internal/config" + "github.com/divyam234/teldrive/internal/tgc" "github.com/divyam234/teldrive/pkg/types" - "github.com/gotd/td/tg" ) func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { @@ -37,31 +37,42 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { } type linearReader struct { - ctx context.Context - parts []types.Part - ranges []types.Range - pos int - client *tg.Client - reader io.ReadCloser - limit int64 - err error - config *config.TGConfig + ctx context.Context + parts []types.Part + ranges []types.Range + pos int + reader io.ReadCloser + limit int64 + config *config.TGConfig + channelId int64 + worker *tgc.StreamWorker + client *tgc.Client + fileId string + concurrency int } func NewLinearReader(ctx context.Context, - client *tg.Client, + fileId string, parts []types.Part, start, end int64, + channelId int64, config *config.TGConfig, + concurrency int, + client *tgc.Client, + worker *tgc.StreamWorker, ) (reader io.ReadCloser, err error) { r := &linearReader{ - ctx: ctx, - parts: parts, - client: client, - limit: end - start + 1, - ranges: calculatePartByteRanges(start, end, parts[0].Size), - config: config, + ctx: ctx, + parts: parts, + limit: end - start + 1, + ranges: calculatePartByteRanges(start, end, parts[0].Size), + config: config, + client: client, + worker: worker, + channelId: channelId, + fileId: fileId, + concurrency: concurrency, } r.reader, err = r.nextPart() @@ -73,25 +84,22 @@ func NewLinearReader(ctx context.Context, return r, nil } -func (r *linearReader) Read(p []byte) (n int, err error) { - - if r.err != nil { - return 0, r.err - } +func (r *linearReader) Read(p []byte) (int, error) { if r.limit <= 0 { return 0, io.EOF } - n, err = r.reader.Read(p) + n, err := r.reader.Read(p) - if err == nil { - r.limit -= int64(n) - } + r.limit -= int64(n) if err == io.EOF { if r.limit > 0 { err = nil + if r.reader != nil { + r.reader.Close() + } } r.pos++ if r.pos < len(r.ranges) { @@ -99,24 +107,18 @@ func (r *linearReader) Read(p []byte) (n int, err error) { } } - r.err = err - return + return n, err } func (r *linearReader) nextPart() (io.ReadCloser, error) { - location := r.parts[r.ranges[r.pos].PartNo].Location - startByte := r.ranges[r.pos].Start - endByte := r.ranges[r.pos].End - rd, err := newTGReader(r.ctx, r.client, location, startByte, endByte) - if err != nil { - return nil, err - } - if r.config.Stream.BufferReader { - return NewAsyncReader(r.ctx, rd, r.config.Stream.Buffers) + start := r.ranges[r.pos].Start + end := r.ranges[r.pos].End - } - return rd, nil + chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, + fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, + client: r.client, concurrency: r.concurrency} + return newTGReader(r.ctx, start, end, r.config, chunkSrc) } diff --git a/internal/reader/tg_reader.go b/internal/reader/tg_reader.go new file mode 100644 index 0000000..4437d2e --- /dev/null +++ b/internal/reader/tg_reader.go @@ -0,0 +1,311 @@ +package reader + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/divyam234/teldrive/internal/config" + "github.com/divyam234/teldrive/internal/tgc" + "github.com/gotd/td/tg" + "golang.org/x/sync/errgroup" +) + +var ErrorStreamAbandoned = errors.New("stream abandoned") + +type ChunkSource interface { + Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) + ChunkSize(start, end int64) int64 +} + +type chunkSource struct { + channelId int64 + worker *tgc.StreamWorker + fileId string + partId int64 + concurrency int + client *tgc.Client +} + +func (c *chunkSource) ChunkSize(start, end int64) int64 { + return tgc.CalculateChunkSize(start, end) +} + +func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) { + var ( + location *tg.InputDocumentFileLocation + err error + client *tgc.Client + ) + + client = c.client + + if c.concurrency > 0 { + client, _, _ = c.worker.Next(c.channelId) + } + location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId) + + if err != nil { + return nil, err + } + + return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit) + +} + +type tgReader struct { + ctx context.Context + offset int64 + limit int64 + chunkSize int64 + bufferChan chan *buffer + done chan struct{} + cur *buffer + err chan error + mu sync.Mutex + concurrency int + leftCut int64 + rightCut int64 + totalParts int + currentPart int + closed bool + timeout time.Duration + chunkSrc ChunkSource +} + +func newTGReader( + ctx context.Context, + start int64, + end int64, + config *config.TGConfig, + chunkSrc ChunkSource, + +) (*tgReader, error) { + + chunkSize := chunkSrc.ChunkSize(start, end) + + offset := start - (start % chunkSize) + + r := &tgReader{ + ctx: ctx, + limit: end - start + 1, + bufferChan: make(chan *buffer, config.Stream.Buffers), + concurrency: config.Stream.MultiThreads, + leftCut: start - offset, + rightCut: (end % chunkSize) + 1, + totalParts: int((end - offset + chunkSize) / chunkSize), + offset: offset, + chunkSize: chunkSize, + chunkSrc: chunkSrc, + timeout: config.Stream.ChunkTimeout, + done: make(chan struct{}, 1), + err: make(chan error, 1), + } + + if r.concurrency == 0 { + r.currentPart = 1 + go r.fillBufferSequentially() + } else { + go r.fillBufferConcurrently() + } + + return r, nil +} + +func (r *tgReader) Close() error { + close(r.done) + close(r.err) + return nil +} + +func (r *tgReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.limit <= 0 { + return 0, io.EOF + } + + if r.cur.isEmpty() { + if r.cur != nil { + r.cur = nil + } + select { + case cur, ok := <-r.bufferChan: + if !ok && r.limit > 0 { + return 0, ErrorStreamAbandoned + } + r.cur = cur + + case err := <-r.err: + return 0, fmt.Errorf("error reading chunk: %w", err) + case <-r.ctx.Done(): + return 0, r.ctx.Err() + + } + } + + n := copy(p, r.cur.buffer()) + r.cur.increment(n) + r.limit -= int64(n) + + if r.limit <= 0 { + return n, io.EOF + } + + return n, nil +} + +func (r *tgReader) fillBufferConcurrently() error { + + var mapMu sync.Mutex + + bufferMap := make(map[int]*buffer) + + defer func() { + close(r.bufferChan) + r.closed = true + for i := range bufferMap { + delete(bufferMap, i) + } + }() + + cb := func(ctx context.Context, i int) func() error { + return func() error { + + chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize) + if err != nil { + return err + } + if r.totalParts == 1 { + chunk = chunk[r.leftCut:r.rightCut] + } else if r.currentPart+i+1 == 1 { + chunk = chunk[r.leftCut:] + } else if r.currentPart+i+1 == r.totalParts { + chunk = chunk[:r.rightCut] + } + buf := &buffer{buf: chunk} + mapMu.Lock() + bufferMap[i] = buf + mapMu.Unlock() + return nil + } + } + + for { + + g := errgroup.Group{} + + g.SetLimit(r.concurrency) + + for i := range r.concurrency { + if r.currentPart+i+1 <= r.totalParts { + g.Go(cb(r.ctx, i)) + } + } + + done := make(chan error, 1) + + go func() { + done <- g.Wait() + }() + + select { + case err := <-done: + if err != nil { + r.err <- err + return nil + } else { + for i := range r.concurrency { + if r.currentPart+i+1 <= r.totalParts { + r.bufferChan <- bufferMap[i] + } + } + r.currentPart += r.concurrency + r.offset += r.chunkSize * int64(r.concurrency) + for i := range bufferMap { + delete(bufferMap, i) + } + if r.currentPart >= r.totalParts { + return nil + } + } + case <-time.After(r.timeout): + return nil + case <-r.done: + return nil + case <-r.ctx.Done(): + return r.ctx.Err() + } + + } +} + +func (r *tgReader) fillBufferSequentially() error { + + defer close(r.bufferChan) + + fetchChunk := func(ctx context.Context) (*buffer, error) { + chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize) + if err != nil { + return nil, err + } + if r.totalParts == 1 { + chunk = chunk[r.leftCut:r.rightCut] + } else if r.currentPart == 1 { + chunk = chunk[r.leftCut:] + } else if r.currentPart == r.totalParts { + chunk = chunk[:r.rightCut] + } + return &buffer{buf: chunk}, nil + } + + for { + select { + case <-r.done: + return nil + case <-r.ctx.Done(): + return r.ctx.Err() + case <-time.After(r.timeout): + return nil + default: + buf, err := fetchChunk(r.ctx) + if err != nil { + r.err <- err + return nil + } + r.bufferChan <- buf + r.currentPart++ + r.offset += r.chunkSize + if r.currentPart > r.totalParts { + return nil + } + } + } +} + +type buffer struct { + buf []byte + offset int +} + +func (b *buffer) isEmpty() bool { + if b == nil { + return true + } + if len(b.buf)-b.offset <= 0 { + return true + } + return false +} + +func (b *buffer) buffer() []byte { + return b.buf[b.offset:] +} + +func (b *buffer) increment(n int) { + b.offset += n +} diff --git a/internal/reader/tg_reader_test.go b/internal/reader/tg_reader_test.go new file mode 100644 index 0000000..0dcc93b --- /dev/null +++ b/internal/reader/tg_reader_test.go @@ -0,0 +1,142 @@ +package reader + +import ( + "context" + "crypto/rand" + "io" + "testing" + "time" + + "github.com/divyam234/teldrive/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type testChunkSource struct { + buffer []byte +} + +func (m *testChunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) { + return m.buffer[offset : offset+limit], nil +} + +func (m *testChunkSource) ChunkSize(start, end int64) int64 { + return 1 +} + +type testChunkSourceTimeout struct { + buffer []byte +} + +func (m *testChunkSourceTimeout) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) { + if offset == 8 { + time.Sleep(2 * time.Second) + } + return m.buffer[offset : offset+limit], nil +} + +func (m *testChunkSourceTimeout) ChunkSize(start, end int64) int64 { + return 1 +} + +type TestSuite struct { + suite.Suite + config *config.TGConfig +} + +func (suite *TestSuite) SetupTest() { + suite.config = &config.TGConfig{Stream: struct { + MultiThreads int + Buffers int + ChunkTimeout time.Duration + }{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}} +} + +func (suite *TestSuite) TestFullRead() { + ctx := context.Background() + start := int64(0) + end := int64(99) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSource{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + test_data, err := io.ReadAll(reader) + assert.Equal(suite.T(), nil, err) + assert.Equal(suite.T(), data[start:end+1], test_data) +} + +func (suite *TestSuite) TestPartialRead() { + ctx := context.Background() + start := int64(0) + end := int64(65) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSource{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + test_data, err := io.ReadAll(reader) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), data[start:end+1], test_data) +} + +func (suite *TestSuite) TestTimeout() { + ctx := context.Background() + start := int64(0) + end := int64(65) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSourceTimeout{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + test_data, err := io.ReadAll(reader) + assert.Greater(suite.T(), len(test_data), 0) + assert.Equal(suite.T(), err, ErrorStreamAbandoned) +} + +func (suite *TestSuite) TestClose() { + ctx := context.Background() + start := int64(0) + end := int64(65) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSource{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + _, err = io.ReadAll(reader) + assert.NoError(suite.T(), err) + assert.NoError(suite.T(), reader.Close()) +} + +func (suite *TestSuite) TestCancellation() { + ctx, cancel := context.WithCancel(context.Background()) + start := int64(0) + end := int64(65) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSource{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + cancel() + _, err = io.ReadAll(reader) + assert.Equal(suite.T(), err, context.Canceled) + assert.Equal(suite.T(), len(reader.bufferChan), 0) +} + +func (suite *TestSuite) TestCancellationWithTimeout() { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + _ = cancel + start := int64(0) + end := int64(65) + data := make([]byte, 100) + rand.Read(data) + chunkSrc := &testChunkSourceTimeout{buffer: data} + reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + assert.NoError(suite.T(), err) + _, err = io.ReadAll(reader) + assert.Equal(suite.T(), err, context.DeadlineExceeded) + assert.Equal(suite.T(), len(reader.bufferChan), 0) +} +func Test(t *testing.T) { + suite.Run(t, new(TestSuite)) +} diff --git a/internal/reader/tgreader.go b/internal/reader/tgreader.go deleted file mode 100644 index 2936512..0000000 --- a/internal/reader/tgreader.go +++ /dev/null @@ -1,144 +0,0 @@ -package reader - -import ( - "context" - "fmt" - "io" - - "github.com/gotd/td/tg" -) - -type tgReader struct { - ctx context.Context - client *tg.Client - location *tg.InputDocumentFileLocation - start int64 - end int64 - next func() ([]byte, error) - buffer []byte - limit int64 - chunkSize int64 - i int64 -} - -func calculateChunkSize(start, end int64) int64 { - chunkSize := int64(1024 * 1024) - - for chunkSize > 1024 && chunkSize > (end-start) { - chunkSize /= 2 - } - - return chunkSize -} - -func newTGReader( - ctx context.Context, - client *tg.Client, - location *tg.InputDocumentFileLocation, - start int64, - end int64, - -) (io.ReadCloser, error) { - - r := &tgReader{ - ctx: ctx, - location: location, - client: client, - start: start, - end: end, - chunkSize: calculateChunkSize(start, end), - limit: end - start + 1, - } - r.next = r.partStream() - return r, nil -} - -func (r *tgReader) Read(p []byte) (n int, err error) { - - if r.limit <= 0 { - return 0, io.EOF - } - - if r.i >= int64(len(r.buffer)) { - r.buffer, err = r.next() - if err != nil { - return 0, err - } - if len(r.buffer) == 0 { - r.next = r.partStream() - r.buffer, err = r.next() - if err != nil { - return 0, err - } - - } - r.i = 0 - } - n = copy(p, r.buffer[r.i:]) - r.i += int64(n) - r.limit -= int64(n) - - return -} - -func (*tgReader) Close() error { - return nil -} - -func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) { - - req := &tg.UploadGetFileRequest{ - Offset: offset, - Limit: int(limit), - Location: r.location, - Precise: true, - } - - res, err := r.client.UploadGetFile(r.ctx, req) - - if err != nil { - return nil, err - } - - switch result := res.(type) { - case *tg.UploadFile: - return result.Bytes, nil - default: - return nil, fmt.Errorf("unexpected type %T", r) - } -} - -func (r *tgReader) partStream() func() ([]byte, error) { - - start := r.start - end := r.end - offset := start - (start % r.chunkSize) - - leftCut := start - offset - rightCut := (end % r.chunkSize) + 1 - totalParts := int((end - offset + r.chunkSize) / r.chunkSize) - currentPart := 1 - - return func() ([]byte, error) { - if currentPart > totalParts { - return make([]byte, 0), nil - } - res, err := r.chunk(offset, r.chunkSize) - if err != nil { - return nil, err - } - if len(res) == 0 { - return res, nil - } else if totalParts == 1 { - res = res[leftCut:rightCut] - } else if currentPart == 1 { - res = res[leftCut:] - } else if currentPart == totalParts { - res = res[:rightCut] - } - - currentPart++ - offset += r.chunkSize - return res, nil - } -} diff --git a/internal/tgc/helpers.go b/internal/tgc/helpers.go new file mode 100644 index 0000000..a6d0a1e --- /dev/null +++ b/internal/tgc/helpers.go @@ -0,0 +1,239 @@ +package tgc + +import ( + "bytes" + "context" + "errors" + "fmt" + "math" + "runtime" + "sync" + + "github.com/divyam234/teldrive/internal/cache" + "github.com/divyam234/teldrive/pkg/types" + "github.com/gotd/td/telegram" + "github.com/gotd/td/tg" + "golang.org/x/sync/errgroup" +) + +var ( + ErrInValidChannelID = errors.New("invalid channel id") + ErrInvalidChannelMessages = errors.New("invalid channel messages") +) + +func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*tg.InputChannel, error) { + inputChannel := &tg.InputChannel{ + ChannelID: channelId, + } + channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel}) + + if err != nil { + return nil, err + } + + if len(channels.GetChats()) == 0 { + return nil, ErrInValidChannelID + } + return channels.GetChats()[0].(*tg.Channel).AsInput(), nil +} + +func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int64, ids []int) error { + + return RunWithAuth(ctx, client, "", func(ctx context.Context) error { + channel, err := GetChannelById(ctx, client.API(), channelId) + + if err != nil { + return err + } + + batchSize := 100 + + batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize))) + + g, _ := errgroup.WithContext(ctx) + + g.SetLimit(runtime.NumCPU()) + + for i := 0; i < batchCount; i++ { + start := i * batchSize + end := min((i+1)*batchSize, len(ids)) + batchIds := ids[start:end] + g.Go(func() error { + messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds} + _, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest) + return err + }) + } + return g.Wait() + }) +} + +func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) { + + msgIds := []tg.InputMessageClass{} + + for _, id := range ids { + msgIds = append(msgIds, &tg.InputMessageID{ID: id}) + } + + messageRequest := tg.ChannelsGetMessagesRequest{ + Channel: channel, + ID: msgIds, + } + + res, err := client.ChannelsGetMessages(ctx, &messageRequest) + + if err != nil { + return nil, err + } + + return res, nil + +} + +func GetMessages(ctx context.Context, client *tg.Client, ids []int, channelId int64) ([]tg.MessageClass, error) { + + channel, err := GetChannelById(ctx, client, channelId) + + if err != nil { + return nil, err + } + + batchSize := 200 + + batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize))) + + g, _ := errgroup.WithContext(ctx) + + g.SetLimit(runtime.NumCPU()) + + messageMap := make(map[int]*tg.MessagesChannelMessages) + + var mapMu sync.Mutex + + for i := range batchCount { + g.Go(func() error { + splitIds := ids[i*batchSize : min((i+1)*batchSize, len(ids))] + res, err := getTGMessagesBatch(ctx, client, channel, splitIds) + if err != nil { + return err + } + messages, ok := res.(*tg.MessagesChannelMessages) + if !ok { + return ErrInvalidChannelMessages + } + mapMu.Lock() + messageMap[i] = messages + mapMu.Unlock() + return nil + }) + + } + + if err = g.Wait(); err != nil { + return nil, err + } + + allMessages := []tg.MessageClass{} + + for i := range batchCount { + allMessages = append(allMessages, messageMap[i].Messages...) + } + + return allMessages, nil +} + +func GetChunk(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) { + req := &tg.UploadGetFileRequest{ + Offset: offset, + Limit: int(limit), + Location: location, + Precise: true, + } + + r, err := client.UploadGetFile(ctx, req) + + if err != nil { + return nil, err + } + + switch result := r.(type) { + case *tg.UploadFile: + return result.Bytes, nil + default: + return nil, fmt.Errorf("unexpected type %T", r) + } +} + +func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) { + offset := int64(0) + limit := int64(1024 * 1024) + buff := &bytes.Buffer{} + for { + r, err := GetChunk(ctx, client, location, offset, limit) + if err != nil { + return buff, err + } + if len(r) == 0 { + break + } + buff.Write(r) + offset += int64(limit) + } + return buff, nil +} + +func GetBotInfo(ctx context.Context, client *telegram.Client, token string) (*types.BotInfo, error) { + var user *tg.User + err := RunWithAuth(ctx, client, token, func(ctx context.Context) error { + user, _ = client.Self(ctx) + return nil + }) + + if err != nil { + return nil, err + } + return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil +} + +func GetLocation(ctx context.Context, client *Client, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) { + + cache := cache.FromContext(ctx) + + key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId) + + err = cache.Get(key, location) + + if err != nil { + channel, err := GetChannelById(ctx, client.Tg.API(), channelId) + + if err != nil { + return nil, err + } + messageRequest := tg.ChannelsGetMessagesRequest{ + Channel: channel, + ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}}, + } + + res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest) + if err != nil { + return nil, err + } + messages, _ := res.(*tg.MessagesChannelMessages) + item := messages.Messages[0].(*tg.Message) + media := item.Media.(*tg.MessageMediaDocument) + document := media.Document.(*tg.Document) + location = document.AsInputDocumentFileLocation() + cache.Set(key, location, 3600) + } + return location, nil +} + +func CalculateChunkSize(start, end int64) int64 { + chunkSize := int64(1024 * 1024) + + for chunkSize > 1024 && chunkSize > (end-start) { + chunkSize /= 2 + } + return chunkSize +} diff --git a/internal/tgc/tgc.go b/internal/tgc/tgc.go index b8b092b..5676180 100644 --- a/internal/tgc/tgc.go +++ b/internal/tgc/tgc.go @@ -55,8 +55,8 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa LangCode: config.LangCode, }, SessionStorage: storage, - RetryInterval: 5 * time.Second, - MaxRetries: 5, + RetryInterval: 2 * time.Second, + MaxRetries: 10, DialTimeout: 10 * time.Second, Middlewares: middlewares, UpdateHandler: handler, diff --git a/internal/tgc/workers.go b/internal/tgc/workers.go index 8a38794..5b42539 100644 --- a/internal/tgc/workers.go +++ b/internal/tgc/workers.go @@ -2,11 +2,11 @@ package tgc import ( "context" + "strings" "sync" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/kv" - "github.com/divyam234/teldrive/internal/pool" "github.com/gotd/td/telegram" ) @@ -42,9 +42,9 @@ func NewUploadWorker() *UploadWorker { type Client struct { Tg *telegram.Client - Pool pool.Pool Stop StopFunc Status string + UserId string } type StreamWorker struct { @@ -69,10 +69,7 @@ func (w *StreamWorker) Set(bots []string, channelId int64) { for _, token := range bots { middlewares := Middlewares(w.cnf, 5) client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...) - c := &Client{Tg: client, Status: "idle"} - if w.cnf.Stream.UsePooling { - c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...) - } + c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]} w.clients[channelId] = append(w.clients[channelId], c) } w.currIdx[channelId] = 0 @@ -108,9 +105,6 @@ func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) middlewares := Middlewares(w.cnf, 5) client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) c := &Client{Tg: client, Status: "idle"} - if w.cnf.Stream.UsePooling { - c.Pool = pool.NewPool(client, int64(w.cnf.PoolSize), middlewares...) - } w.clients[userId] = append(w.clients[userId], c) } nextClient := w.clients[userId][0] diff --git a/pkg/controller/file.go b/pkg/controller/file.go index cdc5f53..af74eb5 100644 --- a/pkg/controller/file.go +++ b/pkg/controller/file.go @@ -3,11 +3,11 @@ package controller import ( "net/http" + "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/logging" "github.com/divyam234/teldrive/pkg/httputil" "github.com/divyam234/teldrive/pkg/schemas" - "github.com/divyam234/teldrive/pkg/services" "github.com/gin-gonic/gin" ) @@ -23,7 +23,7 @@ func (fc *Controller) CreateFile(c *gin.Context) { return } - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) res, err := fc.FileService.CreateFile(c, userId, &fileIn) if err != nil { @@ -36,7 +36,7 @@ func (fc *Controller) CreateFile(c *gin.Context) { func (fc *Controller) UpdateFile(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) var fileUpdate schemas.FileUpdate @@ -65,7 +65,7 @@ func (fc *Controller) GetFileByID(c *gin.Context) { func (fc *Controller) ListFiles(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) fquery := schemas.FileQuery{ PerPage: 500, @@ -90,7 +90,7 @@ func (fc *Controller) ListFiles(c *gin.Context) { func (fc *Controller) MakeDirectory(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) var payload schemas.MkDir if err := c.ShouldBindJSON(&payload); err != nil { @@ -118,7 +118,7 @@ func (fc *Controller) CopyFile(c *gin.Context) { func (fc *Controller) MoveFiles(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) var payload schemas.FileOperation if err := c.ShouldBindJSON(&payload); err != nil { @@ -136,7 +136,7 @@ func (fc *Controller) MoveFiles(c *gin.Context) { func (fc *Controller) DeleteFiles(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) var payload schemas.DeleteOperation if err := c.ShouldBindJSON(&payload); err != nil { @@ -164,7 +164,7 @@ func (fc *Controller) DeleteFileParts(c *gin.Context) { } func (fc *Controller) MoveDirectory(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) var payload schemas.DirMove if err := c.ShouldBindJSON(&payload); err != nil { @@ -181,7 +181,7 @@ func (fc *Controller) MoveDirectory(c *gin.Context) { } func (fc *Controller) GetCategoryStats(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) res, err := fc.FileService.GetCategoryStats(userId) if err != nil { @@ -193,5 +193,9 @@ func (fc *Controller) GetCategoryStats(c *gin.Context) { } func (fc *Controller) GetFileStream(c *gin.Context) { - fc.FileService.GetFileStream(c) + fc.FileService.GetFileStream(c, false) +} + +func (fc *Controller) GetFileDownload(c *gin.Context) { + fc.FileService.GetFileStream(c, true) } diff --git a/pkg/controller/upload.go b/pkg/controller/upload.go index ecd0172..403f65e 100644 --- a/pkg/controller/upload.go +++ b/pkg/controller/upload.go @@ -4,8 +4,8 @@ import ( "net/http" "strconv" + "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/pkg/httputil" - "github.com/divyam234/teldrive/pkg/services" "github.com/gin-gonic/gin" ) @@ -40,7 +40,7 @@ func (uc *Controller) UploadFile(c *gin.Context) { } func (uc *Controller) UploadStats(c *gin.Context) { - userId, _ := services.GetUserAuth(c) + userId, _ := auth.GetUser(c) days := 7 diff --git a/pkg/cron/cron.go b/pkg/cron/cron.go index a1ac86b..ab86c06 100644 --- a/pkg/cron/cron.go +++ b/pkg/cron/cron.go @@ -6,9 +6,9 @@ import ( "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/logging" + "github.com/divyam234/teldrive/internal/tgc" "github.com/divyam234/teldrive/pkg/models" "github.com/divyam234/teldrive/pkg/schemas" - "github.com/divyam234/teldrive/pkg/services" "github.com/go-co-op/gocron" "github.com/jackc/pgx/v5/pgtype" "go.uber.org/zap" @@ -87,7 +87,8 @@ func (c *CronService) CleanFiles(ctx context.Context) { } } - err := services.DeleteTGMessages(ctx, &c.cnf.TG, row.Session, row.ChannelId, row.UserId, ids) + client, _ := tgc.AuthClient(ctx, &c.cnf.TG, row.Session) + err := tgc.DeleteMessages(ctx, client, row.ChannelId, ids) if err != nil { c.logger.Errorw("failed to delete messages", err) @@ -122,7 +123,9 @@ func (c *CronService) CleanUploads(ctx context.Context) { for _, result := range upResults { if result.Session != "" && len(result.Parts) > 0 { - err := services.DeleteTGMessages(ctx, &c.cnf.TG, result.Session, result.ChannelId, result.UserId, result.Parts) + client, _ := tgc.AuthClient(ctx, &c.cnf.TG, result.Session) + + err := tgc.DeleteMessages(ctx, client, result.ChannelId, result.Parts) if err != nil { c.logger.Errorw("failed to delete messages", err) return diff --git a/pkg/services/common.go b/pkg/services/common.go index 02a9cea..05b21c6 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -1,214 +1,21 @@ package services import ( - "bytes" "context" - "crypto/rand" - "encoding/binary" "fmt" - "io" - "math" - "sort" - "strconv" - "sync" "github.com/divyam234/teldrive/internal/cache" - "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/crypt" - "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/tgc" "github.com/divyam234/teldrive/pkg/models" "github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/types" - "github.com/gin-gonic/gin" - "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/pkg/errors" - "github.com/thoas/go-funk" - "golang.org/x/sync/errgroup" "gorm.io/gorm" ) -type buffer struct { - Buf []byte -} - -func (b *buffer) long() (int64, error) { - v, err := b.uint64() - if err != nil { - return 0, err - } - return int64(v), nil -} -func (b *buffer) uint64() (uint64, error) { - const size = 8 - if len(b.Buf) < size { - return 0, io.ErrUnexpectedEOF - } - v := binary.LittleEndian.Uint64(b.Buf) - b.Buf = b.Buf[size:] - return v, nil -} - -func randInt64() (int64, error) { - var buf [8]byte - if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil { - return 0, err - } - b := &buffer{Buf: buf[:]} - return b.long() -} - -type batchResult struct { - Index int - Messages *tg.MessagesChannelMessages -} - -func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) { - - req := &tg.UploadGetFileRequest{ - Offset: offset, - Limit: int(limit), - Location: location, - } - - r, err := tgClient.API().UploadGetFile(ctx, req) - - if err != nil { - return nil, err - } - - switch result := r.(type) { - case *tg.UploadFile: - return result.Bytes, nil - default: - return nil, fmt.Errorf("unexpected type %T", r) - } -} - -func iterContent(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass) (*bytes.Buffer, error) { - offset := int64(0) - limit := int64(1024 * 1024) - buff := &bytes.Buffer{} - for { - r, err := getChunk(ctx, tgClient, location, offset, limit) - if err != nil { - return buff, err - } - if len(r) == 0 { - break - } - buff.Write(r) - offset += int64(limit) - } - return buff, nil -} - -func GetUserAuth(c *gin.Context) (int64, string) { - val, _ := c.Get("jwtUser") - jwtUser := val.(*types.JWTClaims) - userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64) - return userId, jwtUser.TgSession -} - -func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) { - client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...) - var user *tg.User - err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error { - user, _ = client.Self(ctx) - return nil - }) - - if err != nil { - return nil, err - } - return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil -} - -func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, index int, - results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) { - - defer wg.Done() - - ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass { - return &tg.InputMessageID{ID: int(part.ID)} - }).([]tg.InputMessageClass) - - messageRequest := tg.ChannelsGetMessagesRequest{ - Channel: channel, - ID: ids, - } - - res, err := client.API().ChannelsGetMessages(ctx, &messageRequest) - if err != nil { - errors <- err - return - } - - messages, ok := res.(*tg.MessagesChannelMessages) - - if !ok { - errors <- fmt.Errorf("unexpected response type: %T", res) - return - } - - results <- batchResult{Index: index, Messages: messages} -} - -func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) { - - channel, err := GetChannelById(ctx, client, channelId, userID) - - if err != nil { - return nil, err - } - - var wg sync.WaitGroup - - batchSize := 200 - - batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize))) - - results := make(chan batchResult, batchCount) - - errors := make(chan error, batchCount) - - for i := range batchCount { - wg.Add(1) - splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))] - go getTGMessagesBatch(ctx, client, channel, splitParts, i, results, errors, &wg) - } - - wg.Wait() - close(results) - close(errors) - - for err := range errors { - if err != nil { - return nil, err - } - } - - batchResults := []batchResult{} - - for result := range results { - batchResults = append(batchResults, result) - } - - sort.Slice(batchResults, func(i, j int) bool { - return batchResults[i].Index < batchResults[j].Index - }) - - allMessages := []tg.MessageClass{} - - for _, result := range batchResults { - allMessages = append(allMessages, result.Messages.GetMessages()...) - } - - return allMessages, nil -} - -func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) { +func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) { cache := cache.FromContext(ctx) parts := []types.Part{} @@ -220,7 +27,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu return parts, nil } - messages, err := getTGMessages(ctx, client, file.Parts, file.ChannelID, userID) + ids := []int{} + for _, part := range file.Parts { + ids = append(ids, int(part.ID)) + } + messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID) if err != nil { return nil, err @@ -230,12 +41,11 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu item := message.(*tg.Message) media := item.Media.(*tg.MessageMediaDocument) document := media.Document.(*tg.Document) - location := document.AsInputDocumentFileLocation() part := types.Part{ - Location: location, - Size: document.Size, - Salt: file.Parts[i].Salt, + ID: file.Parts[i].ID, + Size: document.Size, + Salt: file.Parts[i].Salt, } if file.Encrypted { part.DecryptedSize, _ = crypt.DecryptedSize(document.Size) @@ -246,27 +56,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu return parts, nil } -func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) { - - channel := &tg.InputChannel{} - inputChannel := &tg.InputChannel{ - ChannelID: channelId, - } - channels, err := client.API().ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel}) - - if err != nil { - return nil, err - } - - if len(channels.GetChats()) == 0 { - return nil, errors.New("no channels found") - } - - channel = channels.GetChats()[0].(*tg.Channel).AsInput() - return channel, nil -} - -func GetDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) { +func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) { cache := cache.FromContext(ctx) var channelId int64 key := fmt.Sprintf("users:channel:%d", userID) @@ -332,37 +122,3 @@ func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Ses return &session, nil } - -func DeleteTGMessages(ctx context.Context, cnf *config.TGConfig, session string, channelId, userId int64, ids []int) error { - - client, _ := tgc.AuthClient(ctx, cnf, session) - - err := tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error { - channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10)) - - if err != nil { - return err - } - - batchSize := 100 - - batchCount := int(math.Ceil(float64(len(ids)) / float64(batchSize))) - - g, _ := errgroup.WithContext(ctx) - - g.SetLimit(8) - - for i := 0; i < batchCount; i++ { - start := i * batchSize - end := min((i+1)*batchSize, len(ids)) - batchIds := ids[start:end] - g.Go(func() error { - messageDeleteRequest := tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: batchIds} - _, err = client.API().ChannelsDeleteMessages(ctx, &messageDeleteRequest) - return err - }) - } - return g.Wait() - }) - return err -} diff --git a/pkg/services/file.go b/pkg/services/file.go index 05760d3..9171380 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -2,7 +2,9 @@ package services import ( "context" + "crypto/rand" "encoding/base64" + "encoding/binary" "fmt" "io" "mime" @@ -12,6 +14,7 @@ import ( "time" "github.com/WinterYukky/gorm-extra-clause-plugin/exclause" + "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/category" "github.com/divyam234/teldrive/internal/config" @@ -35,6 +38,36 @@ import ( "gorm.io/gorm/clause" ) +type buffer struct { + Buf []byte +} + +func (b *buffer) long() (int64, error) { + v, err := b.uint64() + if err != nil { + return 0, err + } + return int64(v), nil +} +func (b *buffer) uint64() (uint64, error) { + const size = 8 + if len(b.Buf) < size { + return 0, io.ErrUnexpectedEOF + } + v := binary.LittleEndian.Uint64(b.Buf) + b.Buf = b.Buf[size:] + return v, nil +} + +func randInt64() (int64, error) { + var buf [8]byte + if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil { + return 0, err + } + b := &buffer{Buf: buf[:]} + return b.long() +} + type FileService struct { db *gorm.DB cnf *config.TGConfig @@ -75,7 +108,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas. channelId := fileIn.ChannelID if fileIn.ChannelID == 0 { var err error - channelId, err = GetDefaultChannel(c, fs.db, userId) + channelId, err = getDefaultChannel(c, fs.db, userId) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusNotFound} } @@ -332,7 +365,9 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess return nil, &types.AppError{Error: err} } - userId, session := GetUserAuth(c) + _, session := auth.GetUser(c) + + client, _ := tgc.AuthClient(c, fs.cnf, session) ids := []int{} @@ -340,7 +375,7 @@ func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Mess ids = append(ids, int(part.ID)) } - err := DeleteTGMessages(c, fs.cnf, session, *file.ChannelID, userId, ids) + err := tgc.DeleteMessages(c, client, *file.ChannelID, ids) if err != nil { return nil, &types.AppError{Error: err} @@ -380,7 +415,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr return nil, &types.AppError{Error: err, Code: http.StatusBadRequest} } - userId, session := GetUserAuth(c) + userId, session := auth.GetUser(c) client, _ := tgc.AuthClient(c, fs.cnf, session) @@ -394,20 +429,24 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr newIds := []schemas.Part{} - channelId, err := GetDefaultChannel(c, fs.db, userId) + channelId, err := getDefaultChannel(c, fs.db, userId) if err != nil { return nil, &types.AppError{Error: err} } err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error { - user := strconv.FormatInt(userId, 10) - messages, err := getTGMessages(c, client, file.Parts, file.ChannelID, user) + ids := []int{} + + for _, part := range file.Parts { + ids = append(ids, int(part.ID)) + } + messages, err := tgc.GetMessages(c, client.API(), ids, file.ChannelID) if err != nil { return err } - channel, err := GetChannelById(ctx, client, channelId, user) + channel, err := tgc.GetChannelById(ctx, client.API(), channelId) if err != nil { return err @@ -482,7 +521,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr return mapper.ToFileOut(dbFile), nil } -func (fs *FileService) GetFileStream(c *gin.Context) { +func (fs *FileService) GetFileStream(c *gin.Context, download bool) { w := c.Writer @@ -585,7 +624,7 @@ func (fs *FileService) GetFileStream(c *gin.Context) { disposition := "inline" - if c.Query("d") == "1" { + if download { disposition = "attachment" } @@ -601,14 +640,15 @@ func (fs *FileService) GetFileStream(c *gin.Context) { } var ( - channelUser string - lr io.ReadCloser + channelUser string + lr io.ReadCloser + client *tgc.Client + multiThreads int ) - var client *tgc.Client + multiThreads = fs.cnf.Stream.MultiThreads if fs.cnf.DisableStreamBots || len(tokens) == 0 { - client, err = fs.worker.UserWorker(session.Session, session.UserId) if err != nil { logger.Error("file stream", zap.Error(err)) @@ -616,47 +656,36 @@ func (fs *FileService) GetFileStream(c *gin.Context) { return } channelUser = strconv.FormatInt(session.UserId, 10) - - logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "user", channelUser, "start", start, - "end", end, "fileSize", file.Size) + multiThreads = 0 } else { - var index int limit := min(len(tokens), fs.cnf.BgBotsLimit) fs.worker.Set(tokens[:limit], file.ChannelID) - - client, index, err = fs.worker.Next(file.ChannelID) - + client, _, err = fs.worker.Next(file.ChannelID) if err != nil { logger.Error("file stream", zap.Error(err)) http.Error(w, err.Error(), http.StatusInternalServerError) return } - channelUser = strings.Split(tokens[index], ":")[0] - logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "botNo", index, "start", start, - "end", end, "fileSize", file.Size) } if r.Method != "HEAD" { - parts, err := getParts(c, client.Tg, file, channelUser) + parts, err := getParts(c, client.Tg.API(), file, channelUser) if err != nil { logger.Error("file stream", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - tgClient := client.Tg.API() - - if fs.cnf.Stream.UsePooling { - tgClient = client.Pool.Default(c) + if download { + multiThreads = 0 } - if file.Encrypted { - lr, err = reader.NewDecryptedReader(c, tgClient, parts, start, end, fs.cnf) + lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker) } else { - lr, err = reader.NewLinearReader(c, tgClient, parts, start, end, fs.cnf) + lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, fs.cnf, multiThreads, client, fs.worker) } if err != nil { @@ -669,7 +698,10 @@ func (fs *FileService) GetFileStream(c *gin.Context) { return } - io.CopyN(w, lr, contentLength) + _, err = io.CopyN(w, lr, contentLength) + if err != nil { + lr.Close() + } } } func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB { diff --git a/pkg/services/upload.go b/pkg/services/upload.go index 4460c21..57a3003 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/logging" @@ -116,7 +117,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty Code: http.StatusBadRequest} } - userId, session := GetUserAuth(c) + userId, session := auth.GetUser(c) uploadId := c.Param("id") @@ -127,7 +128,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty defer fileStream.Close() if uploadQuery.ChannelID == 0 { - channelId, err = GetDefaultChannel(c, us.db, userId) + channelId, err = getDefaultChannel(c, us.db, userId) if err != nil { return nil, &types.AppError{Error: err} } @@ -174,7 +175,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error { - channel, err := GetChannelById(ctx, client, channelId, channelUser) + channel, err := tgc.GetChannelById(ctx, client.API(), channelId) if err != nil { return err diff --git a/pkg/services/user.go b/pkg/services/user.go index ec918e2..a889cb7 100644 --- a/pkg/services/user.go +++ b/pkg/services/user.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/kv" @@ -41,7 +42,7 @@ func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV) *UserService { return &UserService{db: db, cnf: cnf, kv: kv} } func (us *UserService) GetProfilePhoto(c *gin.Context) { - _, session := GetUserAuth(c) + _, session := auth.GetUser(c) client, err := tgc.AuthClient(c, &us.cnf.TG, session) @@ -64,7 +65,7 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) { return errors.New("profile not found") } location := &tg.InputPeerPhotoFileLocation{Big: false, Peer: peer, PhotoID: photo.PhotoID} - buff, err := iterContent(c, client, location) + buff, err := tgc.GetMediaContent(c, client.API(), location) if err != nil { return err } @@ -83,13 +84,13 @@ func (us *UserService) GetProfilePhoto(c *gin.Context) { } func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.AppError) { - userID, _ := GetUserAuth(c) + userID, _ := auth.GetUser(c) var ( channelId int64 err error ) - channelId, _ = GetDefaultChannel(c, us.db, userID) + channelId, _ = getDefaultChannel(c, us.db, userID) tokens, err := getBotsToken(c, us.db, userID, channelId) @@ -103,7 +104,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A cache := cache.FromContext(c) - userId, _ := GetUserAuth(c) + userId, _ := auth.GetUser(c) var payload schemas.Channel @@ -130,7 +131,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A } func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *types.AppError) { - userId, userSession := GetUserAuth(c) + userId, userSession := auth.GetUser(c) client, _ := tgc.AuthClient(c, &us.cnf.TG, userSession) @@ -185,7 +186,7 @@ func (us *UserService) ListSessions(c *gin.Context) ([]schemas.SessionOut, *type func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.AppError) { - userId, _ := GetUserAuth(c) + userId, _ := auth.GetUser(c) session := &models.Session{} @@ -209,7 +210,7 @@ func (us *UserService) RemoveSession(c *gin.Context) (*schemas.Message, *types.A } func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppError) { - _, session := GetUserAuth(c) + _, session := auth.GetUser(c) client, _ := tgc.AuthClient(c, &us.cnf.TG, session) channels := make(map[int64]*schemas.Channel) @@ -236,7 +237,7 @@ func (us *UserService) ListChannels(c *gin.Context) (interface{}, *types.AppErro } func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppError) { - userId, session := GetUserAuth(c) + userId, session := auth.GetUser(c) client, _ := tgc.AuthClient(c, &us.cnf.TG, session) var botsTokens []string @@ -249,7 +250,7 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro return &schemas.Message{Message: "no bots to add"}, nil } - channelId, err := GetDefaultChannel(c, us.db, userId) + channelId, err := getDefaultChannel(c, us.db, userId) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} @@ -263,9 +264,9 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE cache := cache.FromContext(c) - userID, _ := GetUserAuth(c) + userID, _ := auth.GetUser(c) - channelId, err := GetDefaultChannel(c, us.db, userID) + channelId, err := getDefaultChannel(c, us.db, userID) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} @@ -294,7 +295,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error { - channel, err := GetChannelById(ctx, client, channelId, strconv.FormatInt(userId, 10)) + channel, err := tgc.GetChannelById(ctx, client.API(), channelId) if err != nil { logger.Error("error", zap.Error(err)) @@ -309,7 +310,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI waitChan <- struct{}{} wg.Add(1) go func(t string) { - info, err := getBotInfo(c, us.kv, &us.cnf.TG, t) + info, err := tgc.GetBotInfo(c, client, t) if err != nil { return } diff --git a/pkg/types/types.go b/pkg/types/types.go index 94e7cfc..5d9bb34 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -3,7 +3,6 @@ package types import ( "github.com/go-jose/go-jose/v3/jwt" "github.com/gotd/td/session" - "github.com/gotd/td/tg" ) type AppError struct { @@ -12,10 +11,10 @@ type AppError struct { } type Part struct { - Location *tg.InputDocumentFileLocation DecryptedSize int64 Size int64 Salt string + ID int64 } type JWTClaims struct {