mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-02 21:32:58 +08:00
refactor: Copy session data to prevent mutation in LoadSession function
This commit is contained in:
parent
d959849278
commit
b92f751375
9 changed files with 36 additions and 51 deletions
|
@ -29,7 +29,9 @@ func (s *Session) LoadSession(_ context.Context) ([]byte, error) {
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
data := make([]byte, len(b))
|
||||
copy(data, b)
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *Session) StoreSession(_ context.Context, data []byte) error {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/divyam234/teldrive/internal/crypt"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type decrpytedReader struct {
|
||||
|
@ -14,7 +14,7 @@ type decrpytedReader struct {
|
|||
parts []types.Part
|
||||
ranges []types.Range
|
||||
pos int
|
||||
client *telegram.Client
|
||||
client *tg.Client
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
err error
|
||||
|
@ -23,7 +23,7 @@ type decrpytedReader struct {
|
|||
|
||||
func NewDecryptedReader(
|
||||
ctx context.Context,
|
||||
client *telegram.Client,
|
||||
client *tg.Client,
|
||||
parts []types.Part,
|
||||
start, end int64,
|
||||
encryptionKey string) (io.ReadCloser, error) {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
||||
|
@ -40,14 +40,14 @@ type linearReader struct {
|
|||
parts []types.Part
|
||||
ranges []types.Range
|
||||
pos int
|
||||
client *telegram.Client
|
||||
client *tg.Client
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
err error
|
||||
}
|
||||
|
||||
func NewLinearReader(ctx context.Context,
|
||||
client *telegram.Client,
|
||||
client *tg.Client,
|
||||
parts []types.Part,
|
||||
start, end int64,
|
||||
) (reader io.ReadCloser, err error) {
|
||||
|
|
|
@ -5,13 +5,12 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type tgReader struct {
|
||||
ctx context.Context
|
||||
client *telegram.Client
|
||||
client *tg.Client
|
||||
location *tg.InputDocumentFileLocation
|
||||
start int64
|
||||
end int64
|
||||
|
@ -34,7 +33,7 @@ func calculateChunkSize(start, end int64) int64 {
|
|||
|
||||
func newTGReader(
|
||||
ctx context.Context,
|
||||
client *telegram.Client,
|
||||
client *tg.Client,
|
||||
location *tg.InputDocumentFileLocation,
|
||||
start int64,
|
||||
end int64,
|
||||
|
@ -95,7 +94,7 @@ func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
|
|||
Precise: true,
|
||||
}
|
||||
|
||||
res, err := r.client.API().UploadGetFile(r.ctx, req)
|
||||
res, err := r.client.UploadGetFile(r.ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -74,7 +74,7 @@ func NoAuthClient(ctx context.Context, config *config.TGConfig, handler telegram
|
|||
return New(ctx, config, handler, storage, middlewares...)
|
||||
}
|
||||
|
||||
func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string) (*telegram.Client, error) {
|
||||
func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
|
||||
data, err := session.TelethonSession(sessionStr)
|
||||
|
||||
if err != nil {
|
||||
|
@ -89,43 +89,27 @@ func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string)
|
|||
if err := loader.Save(context.TODO(), data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
middlewares := []telegram.Middleware{
|
||||
floodwait.NewSimpleWaiter(),
|
||||
}
|
||||
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*
|
||||
time.Duration(config.Rate)), config.RateBurst))
|
||||
return New(ctx, config, nil, storage, middlewares...)
|
||||
}
|
||||
|
||||
func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, retries int, passMiddleware bool) (*telegram.Client, []telegram.Middleware, error) {
|
||||
func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
|
||||
|
||||
storage := kv.NewSession(KV, kv.Key("botsession", token))
|
||||
|
||||
return New(ctx, config, nil, storage, middlewares...)
|
||||
|
||||
}
|
||||
|
||||
func Middlewares(config *config.TGConfig, retries int) []telegram.Middleware {
|
||||
middlewares := []telegram.Middleware{
|
||||
floodwait.NewSimpleWaiter(),
|
||||
recovery.New(ctx, newBackoff(config.ReconnectTimeout)),
|
||||
recovery.New(context.Background(), newBackoff(config.ReconnectTimeout)),
|
||||
retry.New(retries),
|
||||
}
|
||||
|
||||
if config.RateLimit {
|
||||
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*
|
||||
time.Duration(config.Rate)), config.RateBurst))
|
||||
}
|
||||
|
||||
if passMiddleware {
|
||||
client, err := New(ctx, config, nil, storage, middlewares...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
}
|
||||
return client, nil, nil
|
||||
} else {
|
||||
client, err := New(ctx, config, nil, storage)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return client, middlewares, nil
|
||||
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
|
||||
}
|
||||
return middlewares
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,8 @@ func (w *StreamWorker) Set(bots []string, channelId int64) {
|
|||
w.currIdx = make(map[int64]int)
|
||||
w.bots[channelId] = bots
|
||||
for _, token := range bots {
|
||||
client, _, _ := BotClient(w.ctx, w.kv, w.cnf, token, 5, true)
|
||||
middlewares := Middlewares(w.cnf, 5)
|
||||
client, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
|
||||
w.clients[channelId] = append(w.clients[channelId], &Client{Tg: client, Status: "idle"})
|
||||
}
|
||||
w.currIdx[channelId] = 0
|
||||
|
@ -90,7 +91,7 @@ func (w *StreamWorker) Next(channelId int64) (*Client, int, error) {
|
|||
return nextClient, index, nil
|
||||
}
|
||||
|
||||
func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Client, error) {
|
||||
func (w *StreamWorker) UserWorker(session string, userId int64) (*Client, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
|
@ -98,6 +99,8 @@ func (w *StreamWorker) UserWorker(client *telegram.Client, userId int64) (*Clien
|
|||
|
||||
if !ok {
|
||||
w.clients = make(map[int64][]*Client)
|
||||
middlewares := Middlewares(w.cnf, 5)
|
||||
client, _ := AuthClient(w.ctx, w.cnf, session, middlewares...)
|
||||
w.clients[userId] = append(w.clients[userId], &Client{Tg: client, Status: "idle"})
|
||||
}
|
||||
nextClient := w.clients[userId][0]
|
||||
|
|
|
@ -111,7 +111,7 @@ func GetUserAuth(c *gin.Context) (int64, string) {
|
|||
}
|
||||
|
||||
func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
|
||||
client, _, _ := tgc.BotClient(ctx, KV, config, token, 5, true)
|
||||
client, _ := tgc.BotClient(ctx, KV, config, token, tgc.Middlewares(config, 5)...)
|
||||
var user *tg.User
|
||||
err := tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
||||
user, _ = client.Self(ctx)
|
||||
|
|
|
@ -615,8 +615,8 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
var client *tgc.Client
|
||||
|
||||
if fs.cnf.DisableStreamBots || len(tokens) == 0 {
|
||||
tgClient, _ := tgc.AuthClient(c, fs.cnf, session.Session)
|
||||
client, err = fs.worker.UserWorker(tgClient, session.UserId)
|
||||
|
||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
@ -655,9 +655,9 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
|
|||
}
|
||||
|
||||
if file.Encrypted {
|
||||
lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey)
|
||||
lr, err = reader.NewDecryptedReader(c, client.Tg.API(), parts, start, end, fs.cnf.Uploads.EncryptionKey)
|
||||
} else {
|
||||
lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end)
|
||||
lr, err = reader.NewLinearReader(c, client.Tg.API(), parts, start, end)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -150,8 +150,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
|||
} else {
|
||||
us.worker.Set(tokens, channelId)
|
||||
token, index = us.worker.Next(channelId)
|
||||
|
||||
client, middlewares, err = tgc.BotClient(c, us.kv, us.cnf, token, us.cnf.Uploads.MaxRetries, false)
|
||||
client, err = tgc.BotClient(c, us.kv, us.cnf, token)
|
||||
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err}
|
||||
|
@ -160,13 +159,11 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
|
|||
channelUser = strings.Split(token, ":")[0]
|
||||
}
|
||||
|
||||
middlewares = tgc.Middlewares(us.cnf, us.cnf.Uploads.MaxRetries)
|
||||
|
||||
uploadPool := pool.NewPool(client, int64(us.cnf.PoolSize), middlewares...)
|
||||
|
||||
defer func() {
|
||||
if uploadPool != nil {
|
||||
uploadPool.Close()
|
||||
}
|
||||
}()
|
||||
defer uploadPool.Close()
|
||||
|
||||
logger := logging.FromContext(c)
|
||||
|
||||
|
|
Loading…
Reference in a new issue