From 2c7db87b3f88ae79bc8b367981b069ad60a74597 Mon Sep 17 00:00:00 2001 From: divyam234 <47589864+divyam234@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:45:26 +0530 Subject: [PATCH] feat: option to close bg bots after a certain time --- cmd/run.go | 10 +- internal/config/config.go | 40 ++++---- internal/reader/decrypted_reader.go | 2 +- internal/reader/tg_reader.go | 6 ++ internal/tgc/workers.go | 138 +++++++++++++++++++--------- pkg/controller/file.go | 5 - pkg/services/file.go | 6 ++ 7 files changed, 133 insertions(+), 74 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 76a7385..857bacb 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -48,7 +48,7 @@ func NewRun() *cobra.Command { }, } - runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)") + runCmd.Flags().StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)") runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port") duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout") @@ -71,9 +71,9 @@ func NewRun() *cobra.Command { runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID") runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash") runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path") - runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting") - runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst") - runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate") + runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client") + runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client") + runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client") runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model", "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model") runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version") @@ -94,6 +94,8 @@ func NewRun() *cobra.Command { runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout") + duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots") + duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots") runCmd.MarkFlagRequired("tg-app-id") runCmd.MarkFlagRequired("tg-app-hash") runCmd.MarkFlagRequired("db-data-source") diff --git a/internal/config/config.go b/internal/config/config.go index 5f27ed5..205fd22 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,25 +18,27 @@ type ServerConfig struct { } type TGConfig struct { - AppId int - AppHash string - RateLimit bool - RateBurst int - Rate int - DeviceModel string - SystemVersion string - AppVersion string - LangCode string - SystemLangCode string - LangPack string - SessionFile string - BgBotsLimit int - DisableStreamBots bool - Proxy string - ReconnectTimeout time.Duration - PoolSize int64 - EnableLogging bool - Uploads struct { + AppId int + AppHash string + RateLimit bool + RateBurst int + Rate int + DeviceModel string + SystemVersion string + AppVersion string + LangCode string + SystemLangCode string + LangPack string + SessionFile string + BgBotsLimit int + DisableStreamBots bool + BgBotsTimeout time.Duration + BgBotsCheckInterval time.Duration + Proxy string + ReconnectTimeout time.Duration + PoolSize int64 + EnableLogging bool + Uploads struct { EncryptionKey string Threads int MaxRetries int diff --git a/internal/reader/decrypted_reader.go b/internal/reader/decrypted_reader.go index 2978b5b..effae79 100644 --- a/internal/reader/decrypted_reader.go +++ b/internal/reader/decrypted_reader.go @@ -112,7 +112,7 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, client: r.client, concurrency: r.concurrency} - return newTGReader(r.ctx, start, end, r.config, chunkSrc) + return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc) }, start, end-start+1) diff --git a/internal/reader/tg_reader.go b/internal/reader/tg_reader.go index 4437d2e..3d6569c 100644 --- a/internal/reader/tg_reader.go +++ b/internal/reader/tg_reader.go @@ -43,6 +43,12 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b client = c.client + defer func() { + if c.concurrency > 0 && client != nil { + defer c.worker.Release(client) + } + }() + if c.concurrency > 0 { client, _, _ = c.worker.Next(c.channelId) } diff --git a/internal/tgc/workers.go b/internal/tgc/workers.go index 5b42539..829b4a8 100644 --- a/internal/tgc/workers.go +++ b/internal/tgc/workers.go @@ -2,12 +2,16 @@ package tgc import ( "context" + "strconv" "strings" "sync" + "time" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/kv" + "github.com/divyam234/teldrive/internal/logging" "github.com/gotd/td/telegram" + "go.uber.org/zap" ) type UploadWorker struct { @@ -41,37 +45,32 @@ func NewUploadWorker() *UploadWorker { } type Client struct { - Tg *telegram.Client - Stop StopFunc - Status string - UserId string + Tg *telegram.Client + Stop StopFunc + Status string + UserId string + lastUsed time.Time + connections int } type StreamWorker struct { - mu sync.Mutex - bots map[int64][]string - clients map[int64][]*Client - currIdx map[int64]int - cnf *config.TGConfig - kv kv.KV - ctx context.Context + mu sync.Mutex + clients map[string]*Client + currIdx map[int64]int + channelBots map[int64][]string + cnf *config.TGConfig + kv kv.KV + ctx context.Context + logger *zap.SugaredLogger } func (w *StreamWorker) Set(bots []string, channelId int64) { + w.mu.Lock() defer w.mu.Unlock() - _, ok := w.bots[channelId] + _, ok := w.channelBots[channelId] if !ok { - w.bots = make(map[int64][]string) - w.clients = make(map[int64][]*Client) - w.currIdx = make(map[int64]int) - w.bots[channelId] = bots - for _, token := range bots { - middlewares := Middlewares(w.cnf, 5) - client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...) - c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]} - w.clients[channelId] = append(w.clients[channelId], c) - } + w.channelBots[channelId] = bots w.currIdx[channelId] = 0 } @@ -81,47 +80,96 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) { w.mu.Lock() defer w.mu.Unlock() index := w.currIdx[channelId] - nextClient := w.clients[channelId][index] - w.currIdx[channelId] = (index + 1) % len(w.clients[channelId]) - if nextClient.Status == "idle" { - stop, err := Connect(nextClient.Tg, WithBotToken(w.bots[channelId][index])) + token := w.channelBots[channelId][index] + userId := strings.Split(token, ":")[0] + client, ok := w.clients[userId] + if !ok || (client.Status == "idle" && client.Stop == nil) { + middlewares := Middlewares(w.cnf, 5) + tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...) + client = &Client{Tg: tgClient, Status: "idle", UserId: userId} + w.clients[userId] = client + stop, err := Connect(client.Tg, WithBotToken(token)) if err != nil { return nil, 0, err } - nextClient.Stop = stop - nextClient.Status = "running" + client.Stop = stop + w.logger.Debug("started bg client: ", client.UserId) + } + w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId]) + client.lastUsed = time.Now() + if client.connections == 0 { + client.Status = "serving" + } + client.connections++ + return client, index, nil +} + +func (w *StreamWorker) Release(client *Client) { + w.mu.Lock() + defer w.mu.Unlock() + client.connections-- + if client.connections == 0 { + client.Status = "running" } - return nextClient, index, nil } func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) { w.mu.Lock() defer w.mu.Unlock() - _, ok := w.clients[userId] - - if !ok { - w.clients = make(map[int64][]*Client) + id := strconv.FormatInt(userId, 10) + client, ok := w.clients[id] + if !ok || (client.Status == "idle" && client.Stop == nil) { middlewares := Middlewares(w.cnf, 5) - client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) - c := &Client{Tg: client, Status: "idle"} - w.clients[userId] = append(w.clients[userId], c) - } - nextClient := w.clients[userId][0] - if nextClient.Status == "idle" { - stop, err := Connect(nextClient.Tg, WithContext(w.ctx)) + tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) + client = &Client{Tg: tgClient, Status: "idle", UserId: id} + w.clients[id] = client + stop, err := Connect(client.Tg, WithContext(w.ctx)) if err != nil { return nil, err } - nextClient.Stop = stop - nextClient.Status = "running" + client.Stop = stop + w.logger.Debug("started bg client: ", client.UserId) } - return nextClient, nil + client.lastUsed = time.Now() + if client.connections == 0 { + client.Status = "serving" + } + client.connections++ + return client, nil +} + +func (w *StreamWorker) startIdleClientMonitor() { + ticker := time.NewTicker(w.cnf.BgBotsCheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + w.mu.Lock() + for _, client := range w.clients { + if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout { + if client.Stop != nil { + client.Stop() + client.Stop = nil + client.Status = "idle" + w.logger.Debug("stopped bg client: ", client.UserId) + } + } + } + w.mu.Unlock() + case <-w.ctx.Done(): + return + } + } + } func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker { return func(cnf *config.Config, kv kv.KV) *StreamWorker { - return &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx} + worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx, + clients: make(map[string]*Client), currIdx: make(map[int64]int), + channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)} + go worker.startIdleClientMonitor() + return worker } - } diff --git a/pkg/controller/file.go b/pkg/controller/file.go index af74eb5..9110a87 100644 --- a/pkg/controller/file.go +++ b/pkg/controller/file.go @@ -5,7 +5,6 @@ import ( "github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/cache" - "github.com/divyam234/teldrive/internal/logging" "github.com/divyam234/teldrive/pkg/httputil" "github.com/divyam234/teldrive/pkg/schemas" "github.com/gin-gonic/gin" @@ -15,10 +14,7 @@ func (fc *Controller) CreateFile(c *gin.Context) { var fileIn schemas.FileIn - logger := logging.FromContext(c) - if err := c.ShouldBindJSON(&fileIn); err != nil { - logger.Error(err) httputil.NewError(c, http.StatusBadRequest, err) return } @@ -27,7 +23,6 @@ func (fc *Controller) CreateFile(c *gin.Context) { res, err := fc.FileService.CreateFile(c, userId, &fileIn) if err != nil { - logger.Error(err) httputil.NewError(c, err.Code, err.Error) return } diff --git a/pkg/services/file.go b/pkg/services/file.go index 9171380..acb582b 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -648,6 +648,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { multiThreads = fs.cnf.Stream.MultiThreads + defer func() { + if client != nil { + fs.worker.Release(client) + } + }() + if fs.cnf.DisableStreamBots || len(tokens) == 0 { client, err = fs.worker.UserWorker(session.Session, session.UserId) if err != nil {