From c7149701f1bead9817205a88e7592dfa39ab397c Mon Sep 17 00:00:00 2001 From: divyam234 <47589864+divyam234@users.noreply.github.com> Date: Sat, 10 Aug 2024 00:07:56 +0530 Subject: [PATCH] refactor: Update dependencies and refactor multireader and workers --- cmd/run.go | 16 +-- internal/config/config.go | 5 +- internal/reader/decrypted_reader.go | 34 ++--- internal/reader/reader.go | 35 ++++-- internal/reader/tg_multi_reader.go | 32 ++--- internal/reader/tg_multi_reader_test.go | 4 +- internal/tgc/helpers.go | 65 +++++----- internal/tgc/workers.go | 161 +++++++++++------------- pkg/services/common.go | 7 +- pkg/services/file.go | 131 +++++++++++-------- pkg/services/file_test.go | 2 +- pkg/services/upload.go | 4 +- 12 files changed, 257 insertions(+), 239 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index b91c8a5..c54da37 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -99,8 +99,8 @@ func NewRun() *cobra.Command { runCmd.Flags().StringVar(&config.TG.SystemLangCode, "tg-system-lang-code", "en-US", "System language code") runCmd.Flags().StringVar(&config.TG.LangPack, "tg-lang-pack", "webk", "Language pack") runCmd.Flags().StringVar(&config.TG.Proxy, "tg-proxy", "", "HTTP OR SOCKS5 proxy URL") - runCmd.Flags().IntVar(&config.TG.BgBotsLimit, "tg-bg-bots-limit", 5, "Background bots limit") - runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable stream bots") + runCmd.Flags().BoolVar(&config.TG.DisableBgBots, "tg-disable-bg-bots", false, "Disable Background bots") + runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable Stream bots") runCmd.Flags().BoolVar(&config.TG.EnableLogging, "tg-enable-logging", false, "Enable telegram client logging") runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key") runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads") @@ -108,12 +108,11 @@ func NewRun() *cobra.Command { runCmd.Flags().Int64Var(&config.TG.PoolSize, "tg-pool-size", 8, "Telegram Session pool size") duration.DurationVar(runCmd.Flags(), &config.TG.ReconnectTimeout, "tg-reconnect-timeout", 5*time.Minute, "Reconnect Timeout") duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration") + duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 4*time.Hour, "Interval for checking Idle background bots") runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") - runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset") - duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout") - duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots") - duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots") + runCmd.Flags().IntVar(&config.TG.Stream.BotsLimit, "tg-stream-bots-limit", 5, "Stream bots limit") + duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 20*time.Second, "Chunk Fetch Timeout") runCmd.MarkFlagRequired("tg-app-id") runCmd.MarkFlagRequired("tg-app-hash") runCmd.MarkFlagRequired("db-data-source") @@ -147,13 +146,14 @@ func runApplication(conf *config.Config) { return cacher }), fx.Supply(logging.DefaultLogger().Desugar()), + fx.Supply(logging.DefaultLogger()), fx.NopLogger, fx.StopTimeout(conf.Server.GracefulShutdown+time.Second), fx.Provide( database.NewDatabase, kv.NewBoltKV, - tgc.NewStreamWorker(ctx), - tgc.NewUploadWorker, + tgc.NewBotWorker, + tgc.NewStreamWorker, services.NewAuthService, services.NewFileService, services.NewUploadService, diff --git a/internal/config/config.go b/internal/config/config.go index 6ba3a9a..36cde70 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,9 +47,8 @@ type TGConfig struct { SystemLangCode string LangPack string SessionFile string - BgBotsLimit int + DisableBgBots bool DisableStreamBots bool - BgBotsTimeout time.Duration BgBotsCheckInterval time.Duration Proxy string ReconnectTimeout time.Duration @@ -62,7 +61,7 @@ type TGConfig struct { Retention time.Duration } Stream struct { - BotsOffset int + BotsLimit int MultiThreads int Buffers int ChunkTimeout time.Duration diff --git a/internal/reader/decrypted_reader.go b/internal/reader/decrypted_reader.go index 6a2ce2c..0ea2f16 100644 --- a/internal/reader/decrypted_reader.go +++ b/internal/reader/decrypted_reader.go @@ -2,53 +2,55 @@ package reader import ( "context" + "fmt" "io" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/internal/tgc" + "github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/types" + "github.com/gotd/td/tg" ) type decrpytedReader struct { ctx context.Context + file *schemas.FileOutFull parts []types.Part ranges []types.Range pos int reader io.ReadCloser limit int64 config *config.TGConfig - channelId int64 worker *tgc.StreamWorker - client *tgc.Client - fileId string + client *tg.Client concurrency int cache cache.Cacher } func NewDecryptedReader( ctx context.Context, - fileId string, + client *tg.Client, + worker *tgc.StreamWorker, + cache cache.Cacher, + file *schemas.FileOutFull, parts []types.Part, - start, end int64, - channelId int64, + start, + end int64, config *config.TGConfig, concurrency int, - client *tgc.Client, - worker *tgc.StreamWorker, - cache cache.Cacher) (*decrpytedReader, error) { +) (*decrpytedReader, error) { r := &decrpytedReader{ ctx: ctx, parts: parts, + file: file, limit: end - start + 1, ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize), config: config, client: client, worker: worker, - channelId: channelId, - fileId: fileId, concurrency: concurrency, cache: cache, } @@ -113,14 +115,16 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { if underlyingLimit >= 0 { end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1) } + partID := r.parts[r.ranges[r.pos].PartNo].ID + chunkSrc := &chunkSource{ - channelID: r.channelId, - worker: r.worker, - fileID: r.fileId, - partID: r.parts[r.ranges[r.pos].PartNo].ID, + channelID: r.file.ChannelID, + partID: partID, client: r.client, concurrency: r.concurrency, cache: r.cache, + key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID), + worker: r.worker, } if r.concurrency < 2 { return newTGReader(r.ctx, underlyingOffset, end, chunkSrc) diff --git a/internal/reader/reader.go b/internal/reader/reader.go index d72b4c6..fe76a01 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -2,12 +2,15 @@ package reader import ( "context" + "fmt" "io" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/tgc" + "github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/types" + "github.com/gotd/td/tg" ) func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { @@ -41,34 +44,40 @@ func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { type LinearReader struct { ctx context.Context + file *schemas.FileOutFull parts []types.Part ranges []types.Range pos int reader io.ReadCloser limit int64 config *config.TGConfig - channelID int64 worker *tgc.StreamWorker - client *tgc.Client - fileID string + client *tg.Client concurrency int cache cache.Cacher } -func NewLinearReader(ctx context.Context, fileID string, parts []types.Part, start, end int64, - channelID int64, config *config.TGConfig, concurrency int, client *tgc.Client, - worker *tgc.StreamWorker, cache cache.Cacher) (io.ReadCloser, error) { +func NewLinearReader(ctx context.Context, + client *tg.Client, + worker *tgc.StreamWorker, + cache cache.Cacher, + file *schemas.FileOutFull, + parts []types.Part, + start, + end int64, + config *config.TGConfig, + concurrency int, +) (io.ReadCloser, error) { r := &LinearReader{ ctx: ctx, parts: parts, + file: file, limit: end - start + 1, ranges: calculatePartByteRanges(start, end, parts[0].Size), config: config, client: client, worker: worker, - channelID: channelID, - fileID: fileID, concurrency: concurrency, cache: cache, } @@ -108,14 +117,16 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) { start := r.ranges[r.pos].Start end := r.ranges[r.pos].End + partID := r.parts[r.ranges[r.pos].PartNo].ID + chunkSrc := &chunkSource{ - channelID: r.channelID, - worker: r.worker, - fileID: r.fileID, - partID: r.parts[r.ranges[r.pos].PartNo].ID, + channelID: r.file.ChannelID, + partID: partID, client: r.client, concurrency: r.concurrency, cache: r.cache, + key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID), + worker: r.worker, } if r.concurrency < 2 { diff --git a/internal/reader/tg_multi_reader.go b/internal/reader/tg_multi_reader.go index afa9df1..92140f4 100644 --- a/internal/reader/tg_multi_reader.go +++ b/internal/reader/tg_multi_reader.go @@ -26,12 +26,12 @@ type ChunkSource interface { type chunkSource struct { channelID int64 - worker *tgc.StreamWorker - fileID string partID int64 concurrency int - client *tgc.Client + client *tg.Client + key string cache cache.Cacher + worker *tgc.StreamWorker } func (c *chunkSource) ChunkSize(start, end int64) int64 { @@ -42,27 +42,31 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b var ( location *tg.InputDocumentFileLocation err error - client *tgc.Client + client *tg.Client ) + err = c.cache.Get(c.key, location) + client = c.client - defer func() { - if c.concurrency > 0 && client != nil { - defer c.worker.Release(client) - } - }() - if c.concurrency > 0 { - client, _, _ = c.worker.Next(c.channelID) + tc, err := c.worker.Next(c.channelID) + if err != nil { + return nil, err + } + client = tc.Tg.API() + } - location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID) if err != nil { - return nil, err + location, err = tgc.GetLocation(ctx, client, c.channelID, c.partID) + if err != nil { + return nil, err + } + c.cache.Set(c.key, location, 30*time.Minute) } - return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit) + return tgc.GetChunk(ctx, client, location, offset, limit) } diff --git a/internal/reader/tg_multi_reader_test.go b/internal/reader/tg_multi_reader_test.go index ae7b000..e80c7e4 100644 --- a/internal/reader/tg_multi_reader_test.go +++ b/internal/reader/tg_multi_reader_test.go @@ -46,11 +46,11 @@ type TestSuite struct { func (suite *TestSuite) SetupTest() { suite.config = &config.TGConfig{Stream: struct { - BotsOffset int + BotsLimit int MultiThreads int Buffers int ChunkTimeout time.Duration - }{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}} + }{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}} } func (suite *TestSuite) TestFullRead() { diff --git a/internal/tgc/helpers.go b/internal/tgc/helpers.go index c50b441..8bf7f9e 100644 --- a/internal/tgc/helpers.go +++ b/internal/tgc/helpers.go @@ -8,9 +8,7 @@ import ( "math" "runtime" "sync" - "time" - "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/pkg/types" @@ -199,44 +197,39 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil } -func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) { +func GetLocation(ctx context.Context, client *tg.Client, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) { - key := fmt.Sprintf("files:location:%s:%s:%d", client.UserID, fileId, partId) - - err = cache.Get(key, location) + channel, err := GetChannelById(ctx, client, channelId) if err != nil { - channel, err := GetChannelById(ctx, client.Tg.API(), channelId) - - if err != nil { - return nil, err - } - messageRequest := tg.ChannelsGetMessagesRequest{ - Channel: channel, - ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}}, - } - - res, err := client.Tg.API().ChannelsGetMessages(ctx, &messageRequest) - if err != nil { - return nil, err - } - - messages, _ := res.(*tg.MessagesChannelMessages) - - if len(messages.Messages) == 0 { - return nil, errors.New("no messages found") - } - - switch item := messages.Messages[0].(type) { - case *tg.MessageEmpty: - return nil, errors.New("no messages found") - case *tg.Message: - media := item.Media.(*tg.MessageMediaDocument) - document := media.Document.(*tg.Document) - location = document.AsInputDocumentFileLocation() - cache.Set(key, location, 30*time.Minute) - } + return nil, err } + messageRequest := tg.ChannelsGetMessagesRequest{ + Channel: channel, + ID: []tg.InputMessageClass{&tg.InputMessageID{ID: int(partId)}}, + } + + res, err := client.ChannelsGetMessages(ctx, &messageRequest) + if err != nil { + return nil, err + } + + messages, _ := res.(*tg.MessagesChannelMessages) + + if len(messages.Messages) == 0 { + return nil, errors.New("no messages found") + } + + switch item := messages.Messages[0].(type) { + case *tg.MessageEmpty: + return nil, errors.New("no messages found") + case *tg.Message: + media := item.Media.(*tg.MessageMediaDocument) + document := media.Document.(*tg.Document) + location = document.AsInputDocumentFileLocation() + + } + return location, nil } diff --git a/internal/tgc/workers.go b/internal/tgc/workers.go index 8bd23ad..ee2c7bf 100644 --- a/internal/tgc/workers.go +++ b/internal/tgc/workers.go @@ -2,39 +2,37 @@ package tgc import ( "context" - "strconv" "strings" "sync" "time" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/kv" - "github.com/divyam234/teldrive/internal/logging" "github.com/gotd/td/telegram" "go.uber.org/zap" ) -type UploadWorker struct { +type BotWorker struct { mu sync.RWMutex bots map[int64][]string currIdx map[int64]int } -func NewUploadWorker() *UploadWorker { - return &UploadWorker{ +func NewBotWorker() *BotWorker { + return &BotWorker{ bots: make(map[int64][]string), currIdx: make(map[int64]int), } } -func (w *UploadWorker) Set(bots []string, channelID int64) { +func (w *BotWorker) Set(bots []string, channelID int64) { w.mu.Lock() defer w.mu.Unlock() w.bots[channelID] = bots w.currIdx[channelID] = 0 } -func (w *UploadWorker) Next(channelID int64) (string, int) { +func (w *BotWorker) Next(channelID int64) (string, int) { w.mu.RLock() defer w.mu.RUnlock() bots := w.bots[channelID] @@ -43,40 +41,48 @@ func (w *UploadWorker) Next(channelID int64) (string, int) { return bots[index], index } +type ClientStatus int + +const ( + StatusIdle ClientStatus = iota + StatusBusy +) + type Client struct { - Tg *telegram.Client - Stop StopFunc - Status string - UserID string - LastUsed time.Time - Connections int + Tg *telegram.Client + Stop StopFunc + Status ClientStatus + UserID string } type StreamWorker struct { - mu sync.RWMutex - clients map[string]*Client - currIdx map[int64]int - channelBots map[int64][]string - cnf *config.TGConfig - kv kv.KV - ctx context.Context - logger *zap.SugaredLogger + mu sync.RWMutex + clients map[string]*Client + currIdx map[int64]int + channelBots map[int64][]string + cnf *config.TGConfig + kv kv.KV + ctx context.Context + logger *zap.SugaredLogger + activeStreams int + cancel context.CancelFunc } -func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker { - return func(cnf *config.Config, kv kv.KV) *StreamWorker { - worker := &StreamWorker{ - cnf: &cnf.TG, - kv: kv, - ctx: ctx, - clients: make(map[string]*Client), - currIdx: make(map[int64]int), - channelBots: make(map[int64][]string), - logger: logging.FromContext(ctx), - } - go worker.startIdleClientMonitor() - return worker +func NewStreamWorker(cnf *config.Config, kv kv.KV, logger *zap.SugaredLogger) *StreamWorker { + ctx, cancel := context.WithCancel(context.Background()) + worker := &StreamWorker{ + cnf: &cnf.TG, + kv: kv, + ctx: ctx, + clients: make(map[string]*Client), + currIdx: make(map[int64]int), + channelBots: make(map[int64][]string), + logger: logger, + cancel: cancel, } + go worker.startIdleClientMonitor() + return worker + } func (w *StreamWorker) Set(bots []string, channelID int64) { @@ -86,7 +92,7 @@ func (w *StreamWorker) Set(bots []string, channelID int64) { w.currIdx[channelID] = 0 } -func (w *StreamWorker) Next(channelID int64) (*Client, int, error) { +func (w *StreamWorker) Next(channelID int64) (*Client, error) { w.mu.Lock() defer w.mu.Unlock() @@ -97,75 +103,51 @@ func (w *StreamWorker) Next(channelID int64) (*Client, int, error) { client, err := w.getOrCreateClient(userID, token) if err != nil { - return nil, 0, err + return nil, err } w.currIdx[channelID] = (index + 1) % len(bots) - client.LastUsed = time.Now() - client.Connections++ - if client.Connections == 1 { - client.Status = "serving" - } - return client, index, nil + return client, nil +} + +func (w *StreamWorker) IncActiveStream() error { + w.mu.Lock() + defer w.mu.Unlock() + + w.activeStreams++ + return nil +} + +func (w *StreamWorker) DecActiveStreams() error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.activeStreams == 0 { + return nil + } + w.activeStreams-- + return nil } func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) { client, ok := w.clients[userID] - if !ok || (client.Status == "idle" && client.Stop == nil) { + if !ok || (client.Status == StatusIdle && client.Stop == nil) { middlewares := Middlewares(w.cnf, 5) tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...) - client = &Client{Tg: tgClient, Status: "idle", UserID: userID} + client = &Client{Tg: tgClient, Status: StatusIdle, UserID: userID} w.clients[userID] = client - stop, err := Connect(client.Tg, WithBotToken(token)) if err != nil { return nil, err } client.Stop = stop - w.logger.Debug("started bg client: ", client.UserID) + client.Status = StatusBusy + w.logger.Debug("started bg client: ", userID) } return client, nil } -func (w *StreamWorker) Release(client *Client) { - w.mu.Lock() - defer w.mu.Unlock() - client.Connections-- - if client.Connections == 0 { - client.Status = "running" - } -} - -func (w *StreamWorker) UserWorker(session string, userID int64) (*Client, error) { - w.mu.Lock() - defer w.mu.Unlock() - - id := strconv.FormatInt(userID, 10) - client, ok := w.clients[id] - if !ok || (client.Status == "idle" && client.Stop == nil) { - middlewares := Middlewares(w.cnf, 5) - tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) - client = &Client{Tg: tgClient, Status: "idle", UserID: id} - w.clients[id] = client - - stop, err := Connect(client.Tg, WithContext(w.ctx)) - if err != nil { - return nil, err - } - client.Stop = stop - w.logger.Debug("started bg client: ", client.UserID) - } - - client.LastUsed = time.Now() - client.Connections++ - if client.Connections == 1 { - client.Status = "serving" - } - - return client, nil -} - func (w *StreamWorker) startIdleClientMonitor() { ticker := time.NewTicker(w.cnf.BgBotsCheckInterval) defer ticker.Stop() @@ -183,15 +165,16 @@ func (w *StreamWorker) startIdleClientMonitor() { func (w *StreamWorker) checkIdleClients() { w.mu.Lock() defer w.mu.Unlock() - - for _, client := range w.clients { - if client.Status == "running" && time.Since(client.LastUsed) > w.cnf.BgBotsTimeout { - if client.Stop != nil { + if w.activeStreams == 0 { + for _, client := range w.clients { + if client.Status == StatusBusy && client.Stop != nil { client.Stop() client.Stop = nil - client.Status = "idle" + client.Tg = nil + client.Status = StatusIdle w.logger.Debug("stopped bg client: ", client.UserID) } } } + } diff --git a/pkg/services/common.go b/pkg/services/common.go index 5389b9d..c793e32 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -11,16 +11,17 @@ import ( "github.com/divyam234/teldrive/pkg/models" "github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/types" + "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/pkg/errors" "gorm.io/gorm" ) -func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *schemas.FileOutFull, userID string) ([]types.Part, error) { +func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *schemas.FileOutFull) ([]types.Part, error) { parts := []types.Part{} - key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID) + key := fmt.Sprintf("files:messages:%s", file.Id) err := cache.Get(key, &parts) @@ -32,7 +33,7 @@ func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file * for _, part := range file.Parts { ids = append(ids, int(part.ID)) } - messages, err := tgc.GetMessages(ctx, client, ids, file.ChannelID) + messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelID) if err != nil { return nil, err diff --git a/pkg/services/file.go b/pkg/services/file.go index 5f0773b..1bd0bd6 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -22,7 +22,7 @@ import ( "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/database" "github.com/divyam234/teldrive/internal/http_range" - "github.com/divyam234/teldrive/internal/logging" + "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/md5" "github.com/divyam234/teldrive/internal/reader" "github.com/divyam234/teldrive/internal/tgc" @@ -32,6 +32,7 @@ import ( "github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/types" "github.com/gin-gonic/gin" + "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "go.uber.org/zap" @@ -76,14 +77,24 @@ func randInt64() (int64, error) { } type FileService struct { - db *gorm.DB - cnf *config.Config - worker *tgc.StreamWorker - cache cache.Cacher + db *gorm.DB + cnf *config.Config + worker *tgc.StreamWorker + botWorker *tgc.BotWorker + cache cache.Cacher + kv kv.KV + logger *zap.SugaredLogger } -func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache cache.Cacher) *FileService { - return &FileService{db: db, cnf: cnf, worker: worker, cache: cache} +func NewFileService( + db *gorm.DB, + cnf *config.Config, + worker *tgc.StreamWorker, + botWorker *tgc.BotWorker, + kv kv.KV, + cache cache.Cacher, + logger *zap.SugaredLogger) *FileService { + return &FileService{db: db, cnf: cnf, worker: worker, botWorker: botWorker, cache: cache, kv: kv, logger: logger} } func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) { @@ -716,84 +727,96 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { tokens, err := getBotsToken(fs.db, fs.cache, session.UserId, file.ChannelID) - logger := logging.FromContext(c) if err != nil { - logger.Error("failed to get bots", zap.Error(err)) - http.Error(w, err.Error(), http.StatusInternalServerError) + fs.handleError(fmt.Errorf("failed to get bots: %w", err), w) return } var ( - channelUser string lr io.ReadCloser - client *tgc.Client + client *telegram.Client multiThreads int + token string ) multiThreads = fs.cnf.TG.Stream.MultiThreads - defer func() { - if client != nil { - fs.worker.Release(client) - } - }() - if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 { - client, err = fs.worker.UserWorker(session.Session, session.UserId) + client, err = tgc.AuthClient(c, &fs.cnf.TG, session.Session) if err != nil { - logger.Error(ErrorStreamAbandoned, err) - http.Error(w, err.Error(), http.StatusInternalServerError) + fs.handleError(err, w) return } - channelUser = strconv.FormatInt(session.UserId, 10) multiThreads = 0 - } else { - offset := fs.cnf.TG.Stream.BotsOffset - 1 - limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset) - fs.worker.Set(tokens[offset:limit], file.ChannelID) - client, _, err = fs.worker.Next(file.ChannelID) + } else if fs.cnf.TG.DisableBgBots && len(tokens) > 0 { + fs.botWorker.Set(tokens, file.ChannelID) + token, _ = fs.botWorker.Next(file.ChannelID) + client, err = tgc.BotClient(c, fs.kv, &fs.cnf.TG, token) if err != nil { - logger.Error(ErrorStreamAbandoned, err) - http.Error(w, err.Error(), http.StatusInternalServerError) + fs.handleError(err, w) + } + multiThreads = 0 + } else { + fs.worker.Set(tokens[0:fs.cnf.TG.Stream.BotsLimit], file.ChannelID) + c, err := fs.worker.Next(file.ChannelID) + if err != nil { + fs.handleError(err, w) return } + client = c.Tg + } + + if download { + multiThreads = 0 } if r.Method != "HEAD" { - parts, err := getParts(c, client.Tg.API(), fs.cache, file, channelUser) - if err != nil { - logger.Error(ErrorStreamAbandoned, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + handleStream := func() error { + parts, err := getParts(c, client, fs.cache, file) + if err != nil { + fs.handleError(err, w) + return nil + } + if file.Encrypted { + lr, err = reader.NewDecryptedReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads) + } else { + lr, err = reader.NewLinearReader(c, client.API(), fs.worker, fs.cache, file, parts, start, end, &fs.cnf.TG, multiThreads) + } - if download { - multiThreads = 0 + if err != nil { + fs.handleError(err, w) + return nil + } + if lr == nil { + fs.handleError(fmt.Errorf("failed to initialise reader"), w) + return nil + } + _, err = io.CopyN(w, lr, contentLength) + if err != nil { + lr.Close() + } + return nil } - if file.Encrypted { - lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache) + if fs.cnf.TG.DisableBgBots { + tgc.RunWithAuth(c, client, token, func(ctx context.Context) error { + return handleStream() + }) } else { - lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache) + fs.worker.IncActiveStream() + defer fs.worker.DecActiveStreams() + handleStream() } - if err != nil { - logger.Error(ErrorStreamAbandoned, err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if lr == nil { - http.Error(w, "failed to initialise reader", http.StatusInternalServerError) - return - } - - _, err = io.CopyN(w, lr, contentLength) - if err != nil { - lr.Close() - } } } +func (fs *FileService) handleError(err error, w http.ResponseWriter) { + fs.logger.Error(err) + http.Error(w, err.Error(), http.StatusInternalServerError) + +} + func getOrder(fquery *schemas.FileQuery) clause.OrderByColumn { sortColumn := utils.CamelToSnake(fquery.Sort) diff --git a/pkg/services/file_test.go b/pkg/services/file_test.go index ba6ae74..5625615 100644 --- a/pkg/services/file_test.go +++ b/pkg/services/file_test.go @@ -20,7 +20,7 @@ type FileServiceSuite struct { func (s *FileServiceSuite) SetupSuite() { s.db = database.NewTestDatabase(s.T(), false) - s.srv = NewFileService(s.db, nil, nil, nil) + s.srv = NewFileService(s.db, nil, nil, nil, nil, nil, nil) } func (s *FileServiceSuite) SetupTest() { diff --git a/pkg/services/upload.go b/pkg/services/upload.go index c8d3871..02101ac 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -39,13 +39,13 @@ const saltLength = 32 type UploadService struct { db *gorm.DB - worker *tgc.UploadWorker + worker *tgc.BotWorker cnf *config.TGConfig kv kv.KV cache cache.Cacher } -func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV, cache cache.Cacher) *UploadService { +func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.BotWorker, kv kv.KV, cache cache.Cacher) *UploadService { return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv, cache: cache} }