refactor: Copy session data to prevent mutation in LoadSession function

This commit is contained in:
divyam234 2024-06-06 00:19:30 +05:30
parent d959849278
commit b92f751375
9 changed files with 36 additions and 51 deletions

View file

@ -29,7 +29,9 @@ func (s *Session) LoadSession(_ context.Context) ([]byte, error) {
} }
return nil, err return nil, err
} }
return b, nil data := make([]byte, len(b))
copy(data, b)
return data, nil
} }
func (s *Session) StoreSession(_ context.Context, data []byte) error { func (s *Session) StoreSession(_ context.Context, data []byte) error {

View file

@ -6,7 +6,7 @@ import (
"github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/pkg/types" "github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram" "github.com/gotd/td/tg"
) )
type decrpytedReader struct { type decrpytedReader struct {
@ -14,7 +14,7 @@ type decrpytedReader struct {
parts []types.Part parts []types.Part
ranges []types.Range ranges []types.Range
pos int pos int
client *telegram.Client client *tg.Client
reader io.ReadCloser reader io.ReadCloser
limit int64 limit int64
err error err error
@ -23,7 +23,7 @@ type decrpytedReader struct {
func NewDecryptedReader( func NewDecryptedReader(
ctx context.Context, ctx context.Context,
client *telegram.Client, client *tg.Client,
parts []types.Part, parts []types.Part,
start, end int64, start, end int64,
encryptionKey string) (io.ReadCloser, error) { encryptionKey string) (io.ReadCloser, error) {

View file

@ -5,7 +5,7 @@ import (
"io" "io"
"github.com/divyam234/teldrive/pkg/types" "github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram" "github.com/gotd/td/tg"
) )
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
@ -40,14 +40,14 @@ type linearReader struct {
parts []types.Part parts []types.Part
ranges []types.Range ranges []types.Range
pos int pos int
client *telegram.Client client *tg.Client
reader io.ReadCloser reader io.ReadCloser
limit int64 limit int64
err error err error
} }
func NewLinearReader(ctx context.Context, func NewLinearReader(ctx context.Context,
client *telegram.Client, client *tg.Client,
parts []types.Part, parts []types.Part,
start, end int64, start, end int64,
) (reader io.ReadCloser, err error) { ) (reader io.ReadCloser, err error) {

View file

@ -5,13 +5,12 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg" "github.com/gotd/td/tg"
) )
type tgReader struct { type tgReader struct {
ctx context.Context ctx context.Context
client *telegram.Client client *tg.Client
location *tg.InputDocumentFileLocation location *tg.InputDocumentFileLocation
start int64 start int64
end int64 end int64
@ -34,7 +33,7 @@ func calculateChunkSize(start, end int64) int64 {
func newTGReader( func newTGReader(
ctx context.Context, ctx context.Context,
client *telegram.Client, client *tg.Client,
location *tg.InputDocumentFileLocation, location *tg.InputDocumentFileLocation,
start int64, start int64,
end int64, end int64,
@ -95,7 +94,7 @@ func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
Precise: true, Precise: true,
} }
res, err := r.client.API().UploadGetFile(r.ctx, req) res, err := r.client.UploadGetFile(r.ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -74,7 +74,7 @@ func NoAuthClient(ctx context.Context, config *config.TGConfig, handler telegram
return New(ctx, config, handler, storage, middlewares...) return New(ctx, config, handler, storage, middlewares...)
} }
func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string) (*telegram.Client, error) { func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
data, err := session.TelethonSession(sessionStr) data, err := session.TelethonSession(sessionStr)
if err != nil { if err != nil {
@ -89,43 +89,27 @@ func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string)
if err := loader.Save(context.TODO(), data); err != nil { if err := loader.Save(context.TODO(), data); err != nil {
return nil, err return nil, err
} }
middlewares := []telegram.Middleware{
floodwait.NewSimpleWaiter(),
}
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*
time.Duration(config.Rate)), config.RateBurst))
return New(ctx, config, nil, storage, middlewares...) return New(ctx, config, nil, storage, middlewares...)
} }
func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, retries int, passMiddleware bool) (*telegram.Client, []telegram.Middleware, error) { func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
storage := kv.NewSession(KV, kv.Key("botsession", token)) storage := kv.NewSession(KV, kv.Key("botsession", token))
return New(ctx, config, nil, storage, middlewares...)
}
func Middlewares(config *config.TGConfig, retries int) []telegram.Middleware {
middlewares := []telegram.Middleware{ middlewares := []telegram.Middleware{
floodwait.NewSimpleWaiter(), floodwait.NewSimpleWaiter(),
recovery.New(ctx, newBackoff(config.ReconnectTimeout)), recovery.New(context.Background(), newBackoff(config.ReconnectTimeout)),
retry.New(retries), retry.New(retries),
} }
if config.RateLimit { if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond* middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
time.Duration(config.Rate)), config.RateBurst))
}
if passMiddleware {
client, err := New(ctx, config, nil, storage, middlewares...)
if err != nil {
return nil, nil, err
}
return client, nil, nil
} else {
client, err := New(ctx, config, nil, storage)
if err != nil {
return nil, nil, err
}
return client, middlewares, nil
} }
return middlewares
} }

View file

@ -65,7 +65,8 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
w.currIdx = make(map[int64]int) w.currIdx = make(map[int64]int)
w.bots[channelId] = bots w.bots[channelId] = bots
for _, token := range bots { for _, token := range bots {
client, _, _ := BotClient(w.ctx, w.kv, w.cnf, token, 5, true) middlewares := Middlewares(w.cnf, 5)
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
w.clients[channelId] = append(w.clients[channelId], &Client{Tg: client, Status: "idle"}) w.clients[channelId] = append(w.clients[channelId], &Client{Tg: client, Status: "idle"})
} }
w.currIdx[channelId] = 0 w.currIdx[channelId] = 0
@ -90,7 +91,7 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
return nextClient, index, nil return nextClient, index, nil
} }
func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Client, error) { func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@ -98,6 +99,8 @@ func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Clien
if !ok { if !ok {
w.clients = make(map[int64][]*Client) w.clients = make(map[int64][]*Client)
middlewares := Middlewares(w.cnf, 5)
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
w.clients[userId] = append(w.clients[userId], &Client{Tg: client, Status: "idle"}) w.clients[userId] = append(w.clients[userId], &Client{Tg: client, Status: "idle"})
} }
nextClient := w.clients[userId][0] nextClient := w.clients[userId][0]

View file

@ -111,7 +111,7 @@ func GetUserAuth(c *gin.Context) (int64, string) {
} }
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) { func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
client, _, _ := tgc.BotClient(ctx, KV, config, token, 5, true) client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
var user *tg.User var user *tg.User
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error { err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx) user, _ = client.Self(ctx)

View file

@ -615,8 +615,8 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
var client *tgc.Client var client *tgc.Client
if fs.cnf.DisableStreamBots || len(tokens) == 0 { if fs.cnf.DisableStreamBots || len(tokens) == 0 {
tgClient, _ := tgc.AuthClient(c, fs.cnf, session.Session)
client, err = fs.worker.UserWorker(tgClient, session.UserId) client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil { if err != nil {
logger.Error("file stream", zap.Error(err)) logger.Error("file stream", zap.Error(err))
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -655,9 +655,9 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
} }
if file.Encrypted { if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey) lr, err = reader.NewDecryptedReader(c, client.Tg.API(), parts, start, end, fs.cnf.Uploads.EncryptionKey)
} else { } else {
lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end) lr, err = reader.NewLinearReader(c, client.Tg.API(), parts, start, end)
} }
if err != nil { if err != nil {

View file

@ -150,8 +150,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
} else { } else {
us.worker.Set(tokens, channelId) us.worker.Set(tokens, channelId)
token, index = us.worker.Next(channelId) token, index = us.worker.Next(channelId)
client, err = tgc.BotClient(c, us.kv, us.cnf, token)
client, middlewares, err = tgc.BotClient(c, us.kv, us.cnf, token, us.cnf.Uploads.MaxRetries, false)
if err != nil { if err != nil {
return nil, &types.AppError{Error: err} return nil, &types.AppError{Error: err}
@ -160,13 +159,11 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
channelUser = strings.Split(token, ":")[0] channelUser = strings.Split(token, ":")[0]
} }
middlewares = tgc.Middlewares(us.cnf, us.cnf.Uploads.MaxRetries)
uploadPool := pool.NewPool(client, int64(us.cnf.PoolSize), middlewares...) uploadPool := pool.NewPool(client, int64(us.cnf.PoolSize), middlewares...)
defer func() { defer uploadPool.Close()
if uploadPool != nil {
uploadPool.Close()
}
}()
logger := logging.FromContext(c) logger := logging.FromContext(c)