feat: option to close bg bots after a certain time

This commit is contained in:
divyam234 2024-06-24 15:45:26 +05:30
parent 64102e801b
commit 2c7db87b3f
7 changed files with 133 additions and 74 deletions

View file

@ -48,7 +48,7 @@ func NewRun() *cobra.Command {
}, },
} }
runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)") runCmd.Flags().StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)")
runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port") runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port")
duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout") duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
@ -71,9 +71,9 @@ func NewRun() *cobra.Command {
runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID") runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash") runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path") runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path")
runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting") runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client")
runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst") runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client")
runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate") runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client")
runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model", runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model") "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model")
runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version") runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version")
@ -94,6 +94,8 @@ func NewRun() *cobra.Command {
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout") duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
runCmd.MarkFlagRequired("tg-app-id") runCmd.MarkFlagRequired("tg-app-id")
runCmd.MarkFlagRequired("tg-app-hash") runCmd.MarkFlagRequired("tg-app-hash")
runCmd.MarkFlagRequired("db-data-source") runCmd.MarkFlagRequired("db-data-source")

View file

@ -18,25 +18,27 @@ type ServerConfig struct {
} }
type TGConfig struct { type TGConfig struct {
AppId int AppId int
AppHash string AppHash string
RateLimit bool RateLimit bool
RateBurst int RateBurst int
Rate int Rate int
DeviceModel string DeviceModel string
SystemVersion string SystemVersion string
AppVersion string AppVersion string
LangCode string LangCode string
SystemLangCode string SystemLangCode string
LangPack string LangPack string
SessionFile string SessionFile string
BgBotsLimit int BgBotsLimit int
DisableStreamBots bool DisableStreamBots bool
Proxy string BgBotsTimeout time.Duration
ReconnectTimeout time.Duration BgBotsCheckInterval time.Duration
PoolSize int64 Proxy string
EnableLogging bool ReconnectTimeout time.Duration
Uploads struct { PoolSize int64
EnableLogging bool
Uploads struct {
EncryptionKey string EncryptionKey string
Threads int Threads int
MaxRetries int MaxRetries int

View file

@ -112,7 +112,7 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency} client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, start, end, r.config, chunkSrc) return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
}, start, end-start+1) }, start, end-start+1)

View file

@ -43,6 +43,12 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
client = c.client client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 { if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId) client, _, _ = c.worker.Next(c.channelId)
} }

View file

@ -2,12 +2,16 @@ package tgc
import ( import (
"context" "context"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/logging"
"github.com/gotd/td/telegram" "github.com/gotd/td/telegram"
"go.uber.org/zap"
) )
type UploadWorker struct { type UploadWorker struct {
@ -41,37 +45,32 @@ func NewUploadWorker() *UploadWorker {
} }
type Client struct { type Client struct {
Tg *telegram.Client Tg *telegram.Client
Stop StopFunc Stop StopFunc
Status string Status string
UserId string UserId string
lastUsed time.Time
connections int
} }
type StreamWorker struct { type StreamWorker struct {
mu sync.Mutex mu sync.Mutex
bots map[int64][]string clients map[string]*Client
clients map[int64][]*Client currIdx map[int64]int
currIdx map[int64]int channelBots map[int64][]string
cnf *config.TGConfig cnf *config.TGConfig
kv kv.KV kv kv.KV
ctx context.Context ctx context.Context
logger *zap.SugaredLogger
} }
func (w *StreamWorker) Set(bots []string, channelId int64) { func (w *StreamWorker) Set(bots []string, channelId int64) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
_, ok := w.bots[channelId] _, ok := w.channelBots[channelId]
if !ok { if !ok {
w.bots = make(map[int64][]string) w.channelBots[channelId] = bots
w.clients = make(map[int64][]*Client)
w.currIdx = make(map[int64]int)
w.bots[channelId] = bots
for _, token := range bots {
middlewares := Middlewares(w.cnf, 5)
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
w.clients[channelId] = append(w.clients[channelId], c)
}
w.currIdx[channelId] = 0 w.currIdx[channelId] = 0
} }
@ -81,47 +80,96 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
index := w.currIdx[channelId] index := w.currIdx[channelId]
nextClient := w.clients[channelId][index] token := w.channelBots[channelId][index]
w.currIdx[channelId] = (index + 1) % len(w.clients[channelId]) userId := strings.Split(token, ":")[0]
if nextClient.Status == "idle" { client, ok := w.clients[userId]
stop, err := Connect(nextClient.Tg, WithBotToken(w.bots[channelId][index])) if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserId: userId}
w.clients[userId] = client
stop, err := Connect(client.Tg, WithBotToken(token))
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
nextClient.Stop = stop client.Stop = stop
nextClient.Status = "running" w.logger.Debug("started bg client: ", client.UserId)
}
w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId])
client.lastUsed = time.Now()
if client.connections == 0 {
client.Status = "serving"
}
client.connections++
return client, index, nil
}
func (w *StreamWorker) Release(client *Client) {
w.mu.Lock()
defer w.mu.Unlock()
client.connections--
if client.connections == 0 {
client.Status = "running"
} }
return nextClient, index, nil
} }
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) { func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
_, ok := w.clients[userId] id := strconv.FormatInt(userId, 10)
client, ok := w.clients[id]
if !ok { if !ok || (client.Status == "idle" && client.Stop == nil) {
w.clients = make(map[int64][]*Client)
middlewares := Middlewares(w.cnf, 5) middlewares := Middlewares(w.cnf, 5)
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
c := &Client{Tg: client, Status: "idle"} client = &Client{Tg: tgClient, Status: "idle", UserId: id}
w.clients[userId] = append(w.clients[userId], c) w.clients[id] = client
} stop, err := Connect(client.Tg, WithContext(w.ctx))
nextClient := w.clients[userId][0]
if nextClient.Status == "idle" {
stop, err := Connect(nextClient.Tg, WithContext(w.ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
nextClient.Stop = stop client.Stop = stop
nextClient.Status = "running" w.logger.Debug("started bg client: ", client.UserId)
} }
return nextClient, nil client.lastUsed = time.Now()
if client.connections == 0 {
client.Status = "serving"
}
client.connections++
return client, nil
}
func (w *StreamWorker) startIdleClientMonitor() {
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
w.mu.Lock()
for _, client := range w.clients {
if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout {
if client.Stop != nil {
client.Stop()
client.Stop = nil
client.Status = "idle"
w.logger.Debug("stopped bg client: ", client.UserId)
}
}
}
w.mu.Unlock()
case <-w.ctx.Done():
return
}
}
} }
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker { func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
return func(cnf *config.Config, kv kv.KV) *StreamWorker { return func(cnf *config.Config, kv kv.KV) *StreamWorker {
return &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx} worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx,
clients: make(map[string]*Client), currIdx: make(map[int64]int),
channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)}
go worker.startIdleClientMonitor()
return worker
} }
} }

View file

@ -5,7 +5,6 @@ import (
"github.com/divyam234/teldrive/internal/auth" "github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/logging"
"github.com/divyam234/teldrive/pkg/httputil" "github.com/divyam234/teldrive/pkg/httputil"
"github.com/divyam234/teldrive/pkg/schemas" "github.com/divyam234/teldrive/pkg/schemas"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -15,10 +14,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
var fileIn schemas.FileIn var fileIn schemas.FileIn
logger := logging.FromContext(c)
if err := c.ShouldBindJSON(&fileIn); err != nil { if err := c.ShouldBindJSON(&fileIn); err != nil {
logger.Error(err)
httputil.NewError(c, http.StatusBadRequest, err) httputil.NewError(c, http.StatusBadRequest, err)
return return
} }
@ -27,7 +23,6 @@ func (fc *Controller) CreateFile(c *gin.Context) {
res, err := fc.FileService.CreateFile(c, userId, &fileIn) res, err := fc.FileService.CreateFile(c, userId, &fileIn)
if err != nil { if err != nil {
logger.Error(err)
httputil.NewError(c, err.Code, err.Error) httputil.NewError(c, err.Code, err.Error)
return return
} }

View file

@ -648,6 +648,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads = fs.cnf.Stream.MultiThreads multiThreads = fs.cnf.Stream.MultiThreads
defer func() {
if client != nil {
fs.worker.Release(client)
}
}()
if fs.cnf.DisableStreamBots || len(tokens) == 0 { if fs.cnf.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId) client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil { if err != nil {