fix: downloads when no of chunks > 200

This commit is contained in:
divyam234 2024-05-20 01:43:26 +05:30
parent af5d7c97c9
commit 7bbcbf80a2
2 changed files with 78 additions and 11 deletions

View file

@ -7,7 +7,10 @@ import (
"encoding/binary"
"fmt"
"io"
"math"
"sort"
"strconv"
"sync"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
@ -54,6 +57,12 @@ func randInt64() (int64, error) {
b := &buffer{Buf: buf[:]}
return b.long()
}
type batchResult struct {
Index int
Messages *tg.MessagesChannelMessages
}
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
@ -115,11 +124,37 @@ func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
}
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) (*tg.MessagesChannelMessages, error) {
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, userID string, index int,
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
return tg.InputMessageClass(&tg.InputMessageID{ID: int(part.ID)})
})
return &tg.InputMessageID{ID: int(part.ID)}
}).([]tg.InputMessageClass)
messageRequest := tg.ChannelsGetMessagesRequest{
Channel: channel,
ID: ids,
}
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
if err != nil {
errors <- err
return
}
messages, ok := res.(*tg.MessagesChannelMessages)
if !ok {
errors <- fmt.Errorf("unexpected response type: %T", res)
return
}
results <- batchResult{Index: index, Messages: messages}
}
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
channel, err := GetChannelById(ctx, client, channelId, userID)
@ -127,17 +162,49 @@ func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas
return nil, err
}
messageRequest := tg.ChannelsGetMessagesRequest{Channel: channel, ID: ids.([]tg.InputMessageClass)}
var wg sync.WaitGroup
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
batchSize := 200
if err != nil {
return nil, err
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
results := make(chan batchResult, batchCount)
errors := make(chan error, batchCount)
for i := 0; i < batchCount; i++ {
wg.Add(1)
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
go getTGMessagesBatch(ctx, client, channel, splitParts, userID, i, results, errors, &wg)
}
messages := res.(*tg.MessagesChannelMessages)
wg.Wait()
close(results)
close(errors)
return messages, nil
for err := range errors {
if err != nil {
return nil, err
}
}
channelResult := []batchResult{}
for result := range results {
channelResult = append(channelResult, result)
}
sort.Slice(channelResult, func(i, j int) bool {
return channelResult[i].Index < channelResult[j].Index
})
allMessages := []tg.MessageClass{}
for _, result := range channelResult {
allMessages = append(allMessages, result.Messages.GetMessages()...)
}
return allMessages, nil
}
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
@ -158,7 +225,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
return nil, err
}
for i, message := range messages.Messages {
for i, message := range messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)

View file

@ -330,7 +330,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
if err != nil {
return err
}
for i, message := range messages.Messages {
for i, message := range messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)