refactor: Copy session data to prevent mutation in LoadSession function

This commit is contained in:
divyam234 2024-06-06 00:19:30 +05:30
parent d959849278
commit b92f751375
9 changed files with 36 additions and 51 deletions

View file

@ -29,7 +29,9 @@ func (s *Session) LoadSession(_ context.Context) ([]byte, error) {
}
return nil, err
}
return b, nil
data := make([]byte, len(b))
copy(data, b)
return data, nil
}
func (s *Session) StoreSession(_ context.Context, data []byte) error {

View file

@ -6,7 +6,7 @@ import (
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)
type decrpytedReader struct {
@ -14,7 +14,7 @@ type decrpytedReader struct {
parts []types.Part
ranges []types.Range
pos int
client *telegram.Client
client *tg.Client
reader io.ReadCloser
limit int64
err error
@ -23,7 +23,7 @@ type decrpytedReader struct {
func NewDecryptedReader(
ctx context.Context,
client *telegram.Client,
client *tg.Client,
parts []types.Part,
start, end int64,
encryptionKey string) (io.ReadCloser, error) {

View file

@ -5,7 +5,7 @@ import (
"io"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
@ -40,14 +40,14 @@ type linearReader struct {
parts []types.Part
ranges []types.Range
pos int
client *telegram.Client
client *tg.Client
reader io.ReadCloser
limit int64
err error
}
func NewLinearReader(ctx context.Context,
client *telegram.Client,
client *tg.Client,
parts []types.Part,
start, end int64,
) (reader io.ReadCloser, err error) {

View file

@ -5,13 +5,12 @@ import (
"fmt"
"io"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)
type tgReader struct {
ctx context.Context
client *telegram.Client
client *tg.Client
location *tg.InputDocumentFileLocation
start int64
end int64
@ -34,7 +33,7 @@ func calculateChunkSize(start, end int64) int64 {
func newTGReader(
ctx context.Context,
client *telegram.Client,
client *tg.Client,
location *tg.InputDocumentFileLocation,
start int64,
end int64,
@ -95,7 +94,7 @@ func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
Precise: true,
}
res, err := r.client.API().UploadGetFile(r.ctx, req)
res, err := r.client.UploadGetFile(r.ctx, req)
if err != nil {
return nil, err

View file

@ -74,7 +74,7 @@ func NoAuthClient(ctx context.Context, config *config.TGConfig, handler telegram
return New(ctx, config, handler, storage, middlewares...)
}
func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string) (*telegram.Client, error) {
func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
data, err := session.TelethonSession(sessionStr)
if err != nil {
@ -89,43 +89,27 @@ func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string)
if err := loader.Save(context.TODO(), data); err != nil {
return nil, err
}
middlewares := []telegram.Middleware{
floodwait.NewSimpleWaiter(),
}
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*
time.Duration(config.Rate)), config.RateBurst))
return New(ctx, config, nil, storage, middlewares...)
}
func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, retries int, passMiddleware bool) (*telegram.Client, []telegram.Middleware, error) {
func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
storage := kv.NewSession(KV, kv.Key("botsession", token))
return New(ctx, config, nil, storage, middlewares...)
}
func Middlewares(config *config.TGConfig, retries int) []telegram.Middleware {
middlewares := []telegram.Middleware{
floodwait.NewSimpleWaiter(),
recovery.New(ctx, newBackoff(config.ReconnectTimeout)),
recovery.New(context.Background(), newBackoff(config.ReconnectTimeout)),
retry.New(retries),
}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*
time.Duration(config.Rate)), config.RateBurst))
}
if passMiddleware {
client, err := New(ctx, config, nil, storage, middlewares...)
if err != nil {
return nil, nil, err
}
return client, nil, nil
} else {
client, err := New(ctx, config, nil, storage)
if err != nil {
return nil, nil, err
}
return client, middlewares, nil
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
}
return middlewares
}

View file

@ -65,7 +65,8 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
w.currIdx = make(map[int64]int)
w.bots[channelId] = bots
for _, token := range bots {
client, _, _ := BotClient(w.ctx, w.kv, w.cnf, token, 5, true)
middlewares := Middlewares(w.cnf, 5)
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
w.clients[channelId] = append(w.clients[channelId], &Client{Tg: client, Status: "idle"})
}
w.currIdx[channelId] = 0
@ -90,7 +91,7 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
return nextClient, index, nil
}
func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Client, error) {
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()
@ -98,6 +99,8 @@ func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Clien
if !ok {
w.clients = make(map[int64][]*Client)
middlewares := Middlewares(w.cnf, 5)
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
w.clients[userId] = append(w.clients[userId], &Client{Tg: client, Status: "idle"})
}
nextClient := w.clients[userId][0]

View file

@ -111,7 +111,7 @@ func GetUserAuth(c *gin.Context) (int64, string) {
}
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
client, _, _ := tgc.BotClient(ctx, KV, config, token, 5, true)
client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
var user *tg.User
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)

View file

@ -615,8 +615,8 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
var client *tgc.Client
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
tgClient, _ := tgc.AuthClient(c, fs.cnf, session.Session)
client, err = fs.worker.UserWorker(tgClient, session.UserId)
client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil {
logger.Error("file stream", zap.Error(err))
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -655,9 +655,9 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
}
if file.Encrypted {
lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey)
lr, err = reader.NewDecryptedReader(c, client.Tg.API(), parts, start, end, fs.cnf.Uploads.EncryptionKey)
} else {
lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end)
lr, err = reader.NewLinearReader(c, client.Tg.API(), parts, start, end)
}
if err != nil {

View file

@ -150,8 +150,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
} else {
us.worker.Set(tokens, channelId)
token, index = us.worker.Next(channelId)
client, middlewares, err = tgc.BotClient(c, us.kv, us.cnf, token, us.cnf.Uploads.MaxRetries, false)
client, err = tgc.BotClient(c, us.kv, us.cnf, token)
if err != nil {
return nil, &types.AppError{Error: err}
@ -160,13 +159,11 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
channelUser = strings.Split(token, ":")[0]
}
middlewares = tgc.Middlewares(us.cnf, us.cnf.Uploads.MaxRetries)
uploadPool := pool.NewPool(client, int64(us.cnf.PoolSize), middlewares...)
defer func() {
if uploadPool != nil {
uploadPool.Close()
}
}()
defer uploadPool.Close()
logger := logging.FromContext(c)