feat: option to close bg bots after a certain time

This commit is contained in:
divyam234 2024-06-24 15:45:26 +05:30
parent 64102e801b
commit 2c7db87b3f
7 changed files with 133 additions and 74 deletions

View file

@ -48,7 +48,7 @@ func NewRun() *cobra.Command {
},
}
runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)")
runCmd.Flags().StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)")
runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port")
duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
@ -71,9 +71,9 @@ func NewRun() *cobra.Command {
runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path")
runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting")
runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst")
runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate")
runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client")
runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client")
runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client")
runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model",
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model")
runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version")
@ -94,6 +94,8 @@ func NewRun() *cobra.Command {
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
runCmd.MarkFlagRequired("tg-app-id")
runCmd.MarkFlagRequired("tg-app-hash")
runCmd.MarkFlagRequired("db-data-source")

View file

@ -18,25 +18,27 @@ type ServerConfig struct {
}
type TGConfig struct {
AppId int
AppHash string
RateLimit bool
RateBurst int
Rate int
DeviceModel string
SystemVersion string
AppVersion string
LangCode string
SystemLangCode string
LangPack string
SessionFile string
BgBotsLimit int
DisableStreamBots bool
Proxy string
ReconnectTimeout time.Duration
PoolSize int64
EnableLogging bool
Uploads struct {
AppId int
AppHash string
RateLimit bool
RateBurst int
Rate int
DeviceModel string
SystemVersion string
AppVersion string
LangCode string
SystemLangCode string
LangPack string
SessionFile string
BgBotsLimit int
DisableStreamBots bool
BgBotsTimeout time.Duration
BgBotsCheckInterval time.Duration
Proxy string
ReconnectTimeout time.Duration
PoolSize int64
EnableLogging bool
Uploads struct {
EncryptionKey string
Threads int
MaxRetries int

View file

@ -112,7 +112,7 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
}, start, end-start+1)

View file

@ -43,6 +43,12 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId)
}

View file

@ -2,12 +2,16 @@ package tgc
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/kv"
"github.com/divyam234/teldrive/internal/logging"
"github.com/gotd/td/telegram"
"go.uber.org/zap"
)
type UploadWorker struct {
@ -41,37 +45,32 @@ func NewUploadWorker() *UploadWorker {
}
type Client struct {
Tg *telegram.Client
Stop StopFunc
Status string
UserId string
Tg *telegram.Client
Stop StopFunc
Status string
UserId string
lastUsed time.Time
connections int
}
type StreamWorker struct {
mu sync.Mutex
bots map[int64][]string
clients map[int64][]*Client
currIdx map[int64]int
cnf *config.TGConfig
kv kv.KV
ctx context.Context
mu sync.Mutex
clients map[string]*Client
currIdx map[int64]int
channelBots map[int64][]string
cnf *config.TGConfig
kv kv.KV
ctx context.Context
logger *zap.SugaredLogger
}
func (w *StreamWorker) Set(bots []string, channelId int64) {
w.mu.Lock()
defer w.mu.Unlock()
_, ok := w.bots[channelId]
_, ok := w.channelBots[channelId]
if !ok {
w.bots = make(map[int64][]string)
w.clients = make(map[int64][]*Client)
w.currIdx = make(map[int64]int)
w.bots[channelId] = bots
for _, token := range bots {
middlewares := Middlewares(w.cnf, 5)
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
w.clients[channelId] = append(w.clients[channelId], c)
}
w.channelBots[channelId] = bots
w.currIdx[channelId] = 0
}
@ -81,47 +80,96 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
index := w.currIdx[channelId]
nextClient := w.clients[channelId][index]
w.currIdx[channelId] = (index + 1) % len(w.clients[channelId])
if nextClient.Status == "idle" {
stop, err := Connect(nextClient.Tg, WithBotToken(w.bots[channelId][index]))
token := w.channelBots[channelId][index]
userId := strings.Split(token, ":")[0]
client, ok := w.clients[userId]
if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserId: userId}
w.clients[userId] = client
stop, err := Connect(client.Tg, WithBotToken(token))
if err != nil {
return nil, 0, err
}
nextClient.Stop = stop
nextClient.Status = "running"
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserId)
}
w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId])
client.lastUsed = time.Now()
if client.connections == 0 {
client.Status = "serving"
}
client.connections++
return client, index, nil
}
func (w *StreamWorker) Release(client *Client) {
w.mu.Lock()
defer w.mu.Unlock()
client.connections--
if client.connections == 0 {
client.Status = "running"
}
return nextClient, index, nil
}
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
_, ok := w.clients[userId]
if !ok {
w.clients = make(map[int64][]*Client)
id := strconv.FormatInt(userId, 10)
client, ok := w.clients[id]
if !ok || (client.Status == "idle" && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
c := &Client{Tg: client, Status: "idle"}
w.clients[userId] = append(w.clients[userId], c)
}
nextClient := w.clients[userId][0]
if nextClient.Status == "idle" {
stop, err := Connect(nextClient.Tg, WithContext(w.ctx))
tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
client = &Client{Tg: tgClient, Status: "idle", UserId: id}
w.clients[id] = client
stop, err := Connect(client.Tg, WithContext(w.ctx))
if err != nil {
return nil, err
}
nextClient.Stop = stop
nextClient.Status = "running"
client.Stop = stop
w.logger.Debug("started bg client: ", client.UserId)
}
return nextClient, nil
client.lastUsed = time.Now()
if client.connections == 0 {
client.Status = "serving"
}
client.connections++
return client, nil
}
func (w *StreamWorker) startIdleClientMonitor() {
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
w.mu.Lock()
for _, client := range w.clients {
if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout {
if client.Stop != nil {
client.Stop()
client.Stop = nil
client.Status = "idle"
w.logger.Debug("stopped bg client: ", client.UserId)
}
}
}
w.mu.Unlock()
case <-w.ctx.Done():
return
}
}
}
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
return &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx}
worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx,
clients: make(map[string]*Client), currIdx: make(map[int64]int),
channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)}
go worker.startIdleClientMonitor()
return worker
}
}

View file

@ -5,7 +5,6 @@ import (
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/logging"
"github.com/divyam234/teldrive/pkg/httputil"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/gin-gonic/gin"
@ -15,10 +14,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
var fileIn schemas.FileIn
logger := logging.FromContext(c)
if err := c.ShouldBindJSON(&fileIn); err != nil {
logger.Error(err)
httputil.NewError(c, http.StatusBadRequest, err)
return
}
@ -27,7 +23,6 @@ func (fc *Controller) CreateFile(c *gin.Context) {
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
if err != nil {
logger.Error(err)
httputil.NewError(c, err.Code, err.Error)
return
}

View file

@ -648,6 +648,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads = fs.cnf.Stream.MultiThreads
defer func() {
if client != nil {
fs.worker.Release(client)
}
}()
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil {