diff --git a/internal/kv/session.go b/internal/kv/session.go index 9ee41c1..3e46620 100644 --- a/internal/kv/session.go +++ b/internal/kv/session.go @@ -29,7 +29,9 @@ func (s *Session) LoadSession(_ context.Context) ([]byte, error) { } return nil, err } - return b, nil + data := make([]byte, len(b)) + copy(data, b) + return data, nil } func (s *Session) StoreSession(_ context.Context, data []byte) error { diff --git a/internal/reader/decrypted-reader.go b/internal/reader/decrypted-reader.go index c99ee71..e451c58 100644 --- a/internal/reader/decrypted-reader.go +++ b/internal/reader/decrypted-reader.go @@ -6,7 +6,7 @@ import ( "github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/pkg/types" - "github.com/gotd/td/telegram" + "github.com/gotd/td/tg" ) type decrpytedReader struct { @@ -14,7 +14,7 @@ type decrpytedReader struct { parts []types.Part ranges []types.Range pos int - client *telegram.Client + client *tg.Client reader io.ReadCloser limit int64 err error @@ -23,7 +23,7 @@ type decrpytedReader struct { func NewDecryptedReader( ctx context.Context, - client *telegram.Client, + client *tg.Client, parts []types.Part, start, end int64, encryptionKey string) (io.ReadCloser, error) { diff --git a/internal/reader/reader.go b/internal/reader/reader.go index e75972c..16f4896 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -5,7 +5,7 @@ import ( "io" "github.com/divyam234/teldrive/pkg/types" - "github.com/gotd/td/telegram" + "github.com/gotd/td/tg" ) func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { @@ -40,14 +40,14 @@ type linearReader struct { parts []types.Part ranges []types.Range pos int - client *telegram.Client + client *tg.Client reader io.ReadCloser limit int64 err error } func NewLinearReader(ctx context.Context, - client *telegram.Client, + client *tg.Client, parts []types.Part, start, end int64, ) (reader io.ReadCloser, err error) { diff --git a/internal/reader/tgreader.go b/internal/reader/tgreader.go index e1df44f..2936512 100644 --- a/internal/reader/tgreader.go +++ b/internal/reader/tgreader.go @@ -5,13 +5,12 @@ import ( "fmt" "io" - "github.com/gotd/td/telegram" "github.com/gotd/td/tg" ) type tgReader struct { ctx context.Context - client *telegram.Client + client *tg.Client location *tg.InputDocumentFileLocation start int64 end int64 @@ -34,7 +33,7 @@ func calculateChunkSize(start, end int64) int64 { func newTGReader( ctx context.Context, - client *telegram.Client, + client *tg.Client, location *tg.InputDocumentFileLocation, start int64, end int64, @@ -95,7 +94,7 @@ func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) { Precise: true, } - res, err := r.client.API().UploadGetFile(r.ctx, req) + res, err := r.client.UploadGetFile(r.ctx, req) if err != nil { return nil, err diff --git a/internal/tgc/tgc.go b/internal/tgc/tgc.go index 266d7c8..b8b092b 100644 --- a/internal/tgc/tgc.go +++ b/internal/tgc/tgc.go @@ -74,7 +74,7 @@ func NoAuthClient(ctx context.Context, config *config.TGConfig, handler telegram return New(ctx, config, handler, storage, middlewares...) } -func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string) (*telegram.Client, error) { +func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string, middlewares ...telegram.Middleware) (*telegram.Client, error) { data, err := session.TelethonSession(sessionStr) if err != nil { @@ -89,43 +89,27 @@ func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string) if err := loader.Save(context.TODO(), data); err != nil { return nil, err } - middlewares := []telegram.Middleware{ - floodwait.NewSimpleWaiter(), - } - middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond* - time.Duration(config.Rate)), config.RateBurst)) return New(ctx, config, nil, storage, middlewares...) } -func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, retries int, passMiddleware bool) (*telegram.Client, []telegram.Middleware, error) { +func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) { storage := kv.NewSession(KV, kv.Key("botsession", token)) + return New(ctx, config, nil, storage, middlewares...) + +} + +func Middlewares(config *config.TGConfig, retries int) []telegram.Middleware { middlewares := []telegram.Middleware{ floodwait.NewSimpleWaiter(), - recovery.New(ctx, newBackoff(config.ReconnectTimeout)), + recovery.New(context.Background(), newBackoff(config.ReconnectTimeout)), retry.New(retries), } - if config.RateLimit { - middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond* - time.Duration(config.Rate)), config.RateBurst)) - } - - if passMiddleware { - client, err := New(ctx, config, nil, storage, middlewares...) - if err != nil { - return nil, nil, err - - } - return client, nil, nil - } else { - client, err := New(ctx, config, nil, storage) - if err != nil { - return nil, nil, err - } - return client, middlewares, nil + middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst)) } + return middlewares } diff --git a/internal/tgc/workers.go b/internal/tgc/workers.go index e36b5ba..5df8a31 100644 --- a/internal/tgc/workers.go +++ b/internal/tgc/workers.go @@ -65,7 +65,8 @@ func (w *StreamWorker) Set(bots []string, channelId int64) { w.currIdx = make(map[int64]int) w.bots[channelId] = bots for _, token := range bots { - client, _, _ := BotClient(w.ctx, w.kv, w.cnf, token, 5, true) + middlewares := Middlewares(w.cnf, 5) + client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...) w.clients[channelId] = append(w.clients[channelId], &Client{Tg: client, Status: "idle"}) } w.currIdx[channelId] = 0 @@ -90,7 +91,7 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) { return nextClient, index, nil } -func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Client, error) { +func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) { w.mu.Lock() defer w.mu.Unlock() @@ -98,6 +99,8 @@ func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Clien if !ok { w.clients = make(map[int64][]*Client) + middlewares := Middlewares(w.cnf, 5) + client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...) w.clients[userId] = append(w.clients[userId], &Client{Tg: client, Status: "idle"}) } nextClient := w.clients[userId][0] diff --git a/pkg/services/common.go b/pkg/services/common.go index 58967a1..5147528 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -111,7 +111,7 @@ func GetUserAuth(c *gin.Context) (int64, string) { } func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) { - client, _, _ := tgc.BotClient(ctx, KV, config, token, 5, true) + client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...) var user *tg.User err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error { user, _ = client.Self(ctx) diff --git a/pkg/services/file.go b/pkg/services/file.go index 9d6a2ce..e24ea5d 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -615,8 +615,8 @@ func (fs *FileService) GetFileStream(c *gin.Context) { var client *tgc.Client if fs.cnf.DisableStreamBots || len(tokens) == 0 { - tgClient, _ := tgc.AuthClient(c, fs.cnf, session.Session) - client, err = fs.worker.UserWorker(tgClient, session.UserId) + + client, err = fs.worker.UserWorker(session.Session, session.UserId) if err != nil { logger.Error("file stream", zap.Error(err)) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -655,9 +655,9 @@ func (fs *FileService) GetFileStream(c *gin.Context) { } if file.Encrypted { - lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey) + lr, err = reader.NewDecryptedReader(c, client.Tg.API(), parts, start, end, fs.cnf.Uploads.EncryptionKey) } else { - lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end) + lr, err = reader.NewLinearReader(c, client.Tg.API(), parts, start, end) } if err != nil { diff --git a/pkg/services/upload.go b/pkg/services/upload.go index d9727ec..4460c21 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -150,8 +150,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty } else { us.worker.Set(tokens, channelId) token, index = us.worker.Next(channelId) - - client, middlewares, err = tgc.BotClient(c, us.kv, us.cnf, token, us.cnf.Uploads.MaxRetries, false) + client, err = tgc.BotClient(c, us.kv, us.cnf, token) if err != nil { return nil, &types.AppError{Error: err} @@ -160,13 +159,11 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty channelUser = strings.Split(token, ":")[0] } + middlewares = tgc.Middlewares(us.cnf, us.cnf.Uploads.MaxRetries) + uploadPool := pool.NewPool(client, int64(us.cnf.PoolSize), middlewares...) - defer func() { - if uploadPool != nil { - uploadPool.Close() - } - }() + defer uploadPool.Close() logger := logging.FromContext(c)