mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-02-24 15:05:41 +08:00
fix: downloads when no of chunks > 200
This commit is contained in:
parent
af5d7c97c9
commit
7bbcbf80a2
2 changed files with 78 additions and 11 deletions
|
@ -7,7 +7,10 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
|
@ -54,6 +57,12 @@ func randInt64() (int64, error) {
|
|||
b := &buffer{Buf: buf[:]}
|
||||
return b.long()
|
||||
}
|
||||
|
||||
type batchResult struct {
|
||||
Index int
|
||||
Messages *tg.MessagesChannelMessages
|
||||
}
|
||||
|
||||
func getChunk(ctx context.Context, tgClient *telegram.Client, location tg.InputFileLocationClass, offset int64, limit int64) ([]byte, error) {
|
||||
|
||||
req := &tg.UploadGetFileRequest{
|
||||
|
@ -115,11 +124,37 @@ func getBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st
|
|||
return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil
|
||||
}
|
||||
|
||||
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) (*tg.MessagesChannelMessages, error) {
|
||||
func getTGMessagesBatch(ctx context.Context, client *telegram.Client, channel *tg.InputChannel, parts []schemas.Part, userID string, index int,
|
||||
results chan<- batchResult, errors chan<- error, wg *sync.WaitGroup) {
|
||||
|
||||
defer wg.Done()
|
||||
|
||||
ids := funk.Map(parts, func(part schemas.Part) tg.InputMessageClass {
|
||||
return tg.InputMessageClass(&tg.InputMessageID{ID: int(part.ID)})
|
||||
})
|
||||
return &tg.InputMessageID{ID: int(part.ID)}
|
||||
}).([]tg.InputMessageClass)
|
||||
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||
Channel: channel,
|
||||
ID: ids,
|
||||
}
|
||||
|
||||
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
messages, ok := res.(*tg.MessagesChannelMessages)
|
||||
|
||||
if !ok {
|
||||
errors <- fmt.Errorf("unexpected response type: %T", res)
|
||||
return
|
||||
}
|
||||
|
||||
results <- batchResult{Index: index, Messages: messages}
|
||||
}
|
||||
|
||||
func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas.Part, channelId int64, userID string) ([]tg.MessageClass, error) {
|
||||
|
||||
channel, err := GetChannelById(ctx, client, channelId, userID)
|
||||
|
||||
|
@ -127,17 +162,49 @@ func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas
|
|||
return nil, err
|
||||
}
|
||||
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{Channel: channel, ID: ids.([]tg.InputMessageClass)}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
res, err := client.API().ChannelsGetMessages(ctx, &messageRequest)
|
||||
batchSize := 200
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
batchCount := int(math.Ceil(float64(len(parts)) / float64(batchSize)))
|
||||
|
||||
results := make(chan batchResult, batchCount)
|
||||
|
||||
errors := make(chan error, batchCount)
|
||||
|
||||
for i := 0; i < batchCount; i++ {
|
||||
wg.Add(1)
|
||||
splitParts := parts[i*batchSize : min((i+1)*batchSize, len(parts))]
|
||||
go getTGMessagesBatch(ctx, client, channel, splitParts, userID, i, results, errors, &wg)
|
||||
}
|
||||
|
||||
messages := res.(*tg.MessagesChannelMessages)
|
||||
wg.Wait()
|
||||
close(results)
|
||||
close(errors)
|
||||
|
||||
return messages, nil
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
channelResult := []batchResult{}
|
||||
|
||||
for result := range results {
|
||||
channelResult = append(channelResult, result)
|
||||
}
|
||||
|
||||
sort.Slice(channelResult, func(i, j int) bool {
|
||||
return channelResult[i].Index < channelResult[j].Index
|
||||
})
|
||||
|
||||
allMessages := []tg.MessageClass{}
|
||||
|
||||
for _, result := range channelResult {
|
||||
allMessages = append(allMessages, result.Messages.GetMessages()...)
|
||||
}
|
||||
|
||||
return allMessages, nil
|
||||
}
|
||||
|
||||
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
|
||||
|
@ -158,7 +225,7 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
|
|||
return nil, err
|
||||
}
|
||||
|
||||
for i, message := range messages.Messages {
|
||||
for i, message := range messages {
|
||||
item := message.(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
|
|
|
@ -330,7 +330,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, message := range messages.Messages {
|
||||
for i, message := range messages {
|
||||
item := message.(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
|
|
Loading…
Reference in a new issue