mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-02 21:32:58 +08:00
feat: option to close bg bots after a certain time
This commit is contained in:
parent
64102e801b
commit
2c7db87b3f
7 changed files with 133 additions and 74 deletions
10
cmd/run.go
10
cmd/run.go
|
@ -48,7 +48,7 @@ func NewRun() *cobra.Command {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
runCmd.Flags().StringP("config", "c", "", "config file (default is $HOME/.teldrive/config.toml)")
|
runCmd.Flags().StringP("config", "c", "", "Config file path (default $HOME/.teldrive/config.toml)")
|
||||||
runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port")
|
runCmd.Flags().IntVarP(&config.Server.Port, "server-port", "p", 8080, "Server port")
|
||||||
duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
|
duration.DurationVar(runCmd.Flags(), &config.Server.GracefulShutdown, "server-graceful-shutdown", 15*time.Second, "Server graceful shutdown timeout")
|
||||||
|
|
||||||
|
@ -71,9 +71,9 @@ func NewRun() *cobra.Command {
|
||||||
runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
|
runCmd.Flags().IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
|
||||||
runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
|
runCmd.Flags().StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
|
||||||
runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path")
|
runCmd.Flags().StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path")
|
||||||
runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting")
|
runCmd.Flags().BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client")
|
||||||
runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst")
|
runCmd.Flags().IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client")
|
||||||
runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate")
|
runCmd.Flags().IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client")
|
||||||
runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model",
|
runCmd.Flags().StringVar(&config.TG.DeviceModel, "tg-device-model",
|
||||||
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model")
|
"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/116.0", "Device model")
|
||||||
runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version")
|
runCmd.Flags().StringVar(&config.TG.SystemVersion, "tg-system-version", "Win32", "System version")
|
||||||
|
@ -94,6 +94,8 @@ func NewRun() *cobra.Command {
|
||||||
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
|
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
|
||||||
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
|
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
|
||||||
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
|
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
|
||||||
|
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
|
||||||
|
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
|
||||||
runCmd.MarkFlagRequired("tg-app-id")
|
runCmd.MarkFlagRequired("tg-app-id")
|
||||||
runCmd.MarkFlagRequired("tg-app-hash")
|
runCmd.MarkFlagRequired("tg-app-hash")
|
||||||
runCmd.MarkFlagRequired("db-data-source")
|
runCmd.MarkFlagRequired("db-data-source")
|
||||||
|
|
|
@ -18,25 +18,27 @@ type ServerConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TGConfig struct {
|
type TGConfig struct {
|
||||||
AppId int
|
AppId int
|
||||||
AppHash string
|
AppHash string
|
||||||
RateLimit bool
|
RateLimit bool
|
||||||
RateBurst int
|
RateBurst int
|
||||||
Rate int
|
Rate int
|
||||||
DeviceModel string
|
DeviceModel string
|
||||||
SystemVersion string
|
SystemVersion string
|
||||||
AppVersion string
|
AppVersion string
|
||||||
LangCode string
|
LangCode string
|
||||||
SystemLangCode string
|
SystemLangCode string
|
||||||
LangPack string
|
LangPack string
|
||||||
SessionFile string
|
SessionFile string
|
||||||
BgBotsLimit int
|
BgBotsLimit int
|
||||||
DisableStreamBots bool
|
DisableStreamBots bool
|
||||||
Proxy string
|
BgBotsTimeout time.Duration
|
||||||
ReconnectTimeout time.Duration
|
BgBotsCheckInterval time.Duration
|
||||||
PoolSize int64
|
Proxy string
|
||||||
EnableLogging bool
|
ReconnectTimeout time.Duration
|
||||||
Uploads struct {
|
PoolSize int64
|
||||||
|
EnableLogging bool
|
||||||
|
Uploads struct {
|
||||||
EncryptionKey string
|
EncryptionKey string
|
||||||
Threads int
|
Threads int
|
||||||
MaxRetries int
|
MaxRetries int
|
||||||
|
|
|
@ -112,7 +112,7 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||||
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||||
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||||
client: r.client, concurrency: r.concurrency}
|
client: r.client, concurrency: r.concurrency}
|
||||||
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
||||||
|
|
||||||
}, start, end-start+1)
|
}, start, end-start+1)
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,12 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
|
||||||
|
|
||||||
client = c.client
|
client = c.client
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if c.concurrency > 0 && client != nil {
|
||||||
|
defer c.worker.Release(client)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if c.concurrency > 0 {
|
if c.concurrency > 0 {
|
||||||
client, _, _ = c.worker.Next(c.channelId)
|
client, _, _ = c.worker.Next(c.channelId)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,16 @@ package tgc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
"github.com/divyam234/teldrive/internal/kv"
|
"github.com/divyam234/teldrive/internal/kv"
|
||||||
|
"github.com/divyam234/teldrive/internal/logging"
|
||||||
"github.com/gotd/td/telegram"
|
"github.com/gotd/td/telegram"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UploadWorker struct {
|
type UploadWorker struct {
|
||||||
|
@ -41,37 +45,32 @@ func NewUploadWorker() *UploadWorker {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Tg *telegram.Client
|
Tg *telegram.Client
|
||||||
Stop StopFunc
|
Stop StopFunc
|
||||||
Status string
|
Status string
|
||||||
UserId string
|
UserId string
|
||||||
|
lastUsed time.Time
|
||||||
|
connections int
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamWorker struct {
|
type StreamWorker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
bots map[int64][]string
|
clients map[string]*Client
|
||||||
clients map[int64][]*Client
|
currIdx map[int64]int
|
||||||
currIdx map[int64]int
|
channelBots map[int64][]string
|
||||||
cnf *config.TGConfig
|
cnf *config.TGConfig
|
||||||
kv kv.KV
|
kv kv.KV
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
logger *zap.SugaredLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StreamWorker) Set(bots []string, channelId int64) {
|
func (w *StreamWorker) Set(bots []string, channelId int64) {
|
||||||
|
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
_, ok := w.bots[channelId]
|
_, ok := w.channelBots[channelId]
|
||||||
if !ok {
|
if !ok {
|
||||||
w.bots = make(map[int64][]string)
|
w.channelBots[channelId] = bots
|
||||||
w.clients = make(map[int64][]*Client)
|
|
||||||
w.currIdx = make(map[int64]int)
|
|
||||||
w.bots[channelId] = bots
|
|
||||||
for _, token := range bots {
|
|
||||||
middlewares := Middlewares(w.cnf, 5)
|
|
||||||
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
|
||||||
c := &Client{Tg: client, Status: "idle", UserId: strings.Split(token, ":")[0]}
|
|
||||||
w.clients[channelId] = append(w.clients[channelId], c)
|
|
||||||
}
|
|
||||||
w.currIdx[channelId] = 0
|
w.currIdx[channelId] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,47 +80,96 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
index := w.currIdx[channelId]
|
index := w.currIdx[channelId]
|
||||||
nextClient := w.clients[channelId][index]
|
token := w.channelBots[channelId][index]
|
||||||
w.currIdx[channelId] = (index + 1) % len(w.clients[channelId])
|
userId := strings.Split(token, ":")[0]
|
||||||
if nextClient.Status == "idle" {
|
client, ok := w.clients[userId]
|
||||||
stop, err := Connect(nextClient.Tg, WithBotToken(w.bots[channelId][index]))
|
if !ok || (client.Status == "idle" && client.Stop == nil) {
|
||||||
|
middlewares := Middlewares(w.cnf, 5)
|
||||||
|
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
||||||
|
client = &Client{Tg: tgClient, Status: "idle", UserId: userId}
|
||||||
|
w.clients[userId] = client
|
||||||
|
stop, err := Connect(client.Tg, WithBotToken(token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
nextClient.Stop = stop
|
client.Stop = stop
|
||||||
nextClient.Status = "running"
|
w.logger.Debug("started bg client: ", client.UserId)
|
||||||
|
}
|
||||||
|
w.currIdx[channelId] = (index + 1) % len(w.channelBots[channelId])
|
||||||
|
client.lastUsed = time.Now()
|
||||||
|
if client.connections == 0 {
|
||||||
|
client.Status = "serving"
|
||||||
|
}
|
||||||
|
client.connections++
|
||||||
|
return client, index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StreamWorker) Release(client *Client) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
client.connections--
|
||||||
|
if client.connections == 0 {
|
||||||
|
client.Status = "running"
|
||||||
}
|
}
|
||||||
return nextClient, index, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
|
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
_, ok := w.clients[userId]
|
id := strconv.FormatInt(userId, 10)
|
||||||
|
client, ok := w.clients[id]
|
||||||
if !ok {
|
if !ok || (client.Status == "idle" && client.Stop == nil) {
|
||||||
w.clients = make(map[int64][]*Client)
|
|
||||||
middlewares := Middlewares(w.cnf, 5)
|
middlewares := Middlewares(w.cnf, 5)
|
||||||
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
tgClient, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
||||||
c := &Client{Tg: client, Status: "idle"}
|
client = &Client{Tg: tgClient, Status: "idle", UserId: id}
|
||||||
w.clients[userId] = append(w.clients[userId], c)
|
w.clients[id] = client
|
||||||
}
|
stop, err := Connect(client.Tg, WithContext(w.ctx))
|
||||||
nextClient := w.clients[userId][0]
|
|
||||||
if nextClient.Status == "idle" {
|
|
||||||
stop, err := Connect(nextClient.Tg, WithContext(w.ctx))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
nextClient.Stop = stop
|
client.Stop = stop
|
||||||
nextClient.Status = "running"
|
w.logger.Debug("started bg client: ", client.UserId)
|
||||||
}
|
}
|
||||||
return nextClient, nil
|
client.lastUsed = time.Now()
|
||||||
|
if client.connections == 0 {
|
||||||
|
client.Status = "serving"
|
||||||
|
}
|
||||||
|
client.connections++
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StreamWorker) startIdleClientMonitor() {
|
||||||
|
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
w.mu.Lock()
|
||||||
|
for _, client := range w.clients {
|
||||||
|
if client.Status == "running" && time.Since(client.lastUsed) > w.cnf.BgBotsTimeout {
|
||||||
|
if client.Stop != nil {
|
||||||
|
client.Stop()
|
||||||
|
client.Stop = nil
|
||||||
|
client.Status = "idle"
|
||||||
|
w.logger.Debug("stopped bg client: ", client.UserId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.mu.Unlock()
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
|
func NewStreamWorker(ctx context.Context) func(cnf *config.Config, kv kv.KV) *StreamWorker {
|
||||||
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
|
return func(cnf *config.Config, kv kv.KV) *StreamWorker {
|
||||||
return &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx}
|
worker := &StreamWorker{cnf: &cnf.TG, kv: kv, ctx: ctx,
|
||||||
|
clients: make(map[string]*Client), currIdx: make(map[int64]int),
|
||||||
|
channelBots: make(map[int64][]string), logger: logging.FromContext(ctx)}
|
||||||
|
go worker.startIdleClientMonitor()
|
||||||
|
return worker
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/auth"
|
"github.com/divyam234/teldrive/internal/auth"
|
||||||
"github.com/divyam234/teldrive/internal/cache"
|
"github.com/divyam234/teldrive/internal/cache"
|
||||||
"github.com/divyam234/teldrive/internal/logging"
|
|
||||||
"github.com/divyam234/teldrive/pkg/httputil"
|
"github.com/divyam234/teldrive/pkg/httputil"
|
||||||
"github.com/divyam234/teldrive/pkg/schemas"
|
"github.com/divyam234/teldrive/pkg/schemas"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
@ -15,10 +14,7 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
||||||
|
|
||||||
var fileIn schemas.FileIn
|
var fileIn schemas.FileIn
|
||||||
|
|
||||||
logger := logging.FromContext(c)
|
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&fileIn); err != nil {
|
if err := c.ShouldBindJSON(&fileIn); err != nil {
|
||||||
logger.Error(err)
|
|
||||||
httputil.NewError(c, http.StatusBadRequest, err)
|
httputil.NewError(c, http.StatusBadRequest, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -27,7 +23,6 @@ func (fc *Controller) CreateFile(c *gin.Context) {
|
||||||
|
|
||||||
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
|
res, err := fc.FileService.CreateFile(c, userId, &fileIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
|
||||||
httputil.NewError(c, err.Code, err.Error)
|
httputil.NewError(c, err.Code, err.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -648,6 +648,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
||||||
|
|
||||||
multiThreads = fs.cnf.Stream.MultiThreads
|
multiThreads = fs.cnf.Stream.MultiThreads
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if client != nil {
|
||||||
|
fs.worker.Release(client)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
||||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue