diff --git a/pkg/services/common.go b/pkg/services/common.go index 6ecf4ec..0e99a24 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -7,7 +7,10 @@ import ( "encoding/binary" "fmt" "io" + "math" + "sort" "strconv" + "sync" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" @@ -54,6 +57,12 @@ func randInt64() (int64, error) { b := &buffer{Buf: buf[:]} return b.long() } + +type batchResult struct { + Index int + Messages *tg.MessagesChannelMessages +} + func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) { req := &tg.UploadGetFileRequest{ @@ -115,11 +124,37 @@ func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil } -func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) (*tg.MessagesChannelMessages, error) { +func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, userID string, index int, + results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) { + + defer wg.Done() ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass { - return tg.InputMessageClass(&tg.InputMessageID{ID: int(part.ID)}) - }) + return &tg.InputMessageID{ID: int(part.ID)} + }).([]tg.InputMessageClass) + + messageRequest := tg.ChannelsGetMessagesRequest{ + Channel: channel, + ID: ids, + } + + res, err := client.API().ChannelsGetMessages(ctx, &messageRequest) + if err != nil { + errors <- err + return + } + + messages, ok := res.(*tg.MessagesChannelMessages) + + if !ok { + errors <- fmt.Errorf("unexpected response type: %T", res) + return + } + + results <- batchResult{Index: index, Messages: messages} +} + +func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) { channel, err := GetChannelById(ctx, client, channelId, userID) @@ -127,17 +162,49 @@ func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas return nil, err } - messageRequest := tg.ChannelsGetMessagesRequest{Channel: channel, ID: ids.([]tg.InputMessageClass)} + var wg sync.WaitGroup - res, err := client.API().ChannelsGetMessages(ctx, &messageRequest) + batchSize := 200 - if err != nil { - return nil, err + batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize))) + + results := make(chan batchResult, batchCount) + + errors := make(chan error, batchCount) + + for i := 0; i < batchCount; i++ { + wg.Add(1) + splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))] + go getTGMessagesBatch(ctx, client, channel, splitParts, userID, i, results, errors, &wg) } - messages := res.(*tg.MessagesChannelMessages) + wg.Wait() + close(results) + close(errors) - return messages, nil + for err := range errors { + if err != nil { + return nil, err + } + } + + channelResult := []batchResult{} + + for result := range results { + channelResult = append(channelResult, result) + } + + sort.Slice(channelResult, func(i, j int) bool { + return channelResult[i].Index < channelResult[j].Index + }) + + allMessages := []tg.MessageClass{} + + for _, result := range channelResult { + allMessages = append(allMessages, result.Messages.GetMessages()...) + } + + return allMessages, nil } func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) { @@ -158,7 +225,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu return nil, err } - for i, message := range messages.Messages { + for i, message := range messages { item := message.(*tg.Message) media := item.Media.(*tg.MessageMediaDocument) document := media.Document.(*tg.Document) diff --git a/pkg/services/file.go b/pkg/services/file.go index fdd5bd0..6f4487b 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -330,7 +330,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr if err != nil { return err } - for i, message := range messages.Messages { + for i, message := range messages { item := message.(*tg.Message) media := item.Media.(*tg.MessageMediaDocument) document := media.Document.(*tg.Document)