diff --git a/internal/reader/decrypted-reader.go b/internal/reader/decrypted-reader.go index 274c3de..fa643d7 100644 --- a/internal/reader/decrypted-reader.go +++ b/internal/reader/decrypted-reader.go @@ -12,6 +12,7 @@ import ( type decrpytedReader struct { ctx context.Context parts []types.Part + ranges []types.Range pos int client *telegram.Client reader io.ReadCloser @@ -24,14 +25,15 @@ func NewDecryptedReader( ctx context.Context, client *telegram.Client, parts []types.Part, - limit int64, + start, end int64, encryptionKey string) (io.ReadCloser, error) { r := &decrpytedReader{ ctx: ctx, parts: parts, client: client, - limit: limit, + limit: end - start + 1, + ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize), encryptionKey: encryptionKey, } res, err := r.nextPart() @@ -68,7 +70,7 @@ func (r *decrpytedReader) Read(p []byte) (n int, err error) { } r.pos++ if r.pos < len(r.parts) { - r.reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos]) + r.reader, err = r.nextPart() } } r.err = err @@ -86,7 +88,11 @@ func (r *decrpytedReader) Close() (err error) { func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { - cipher, _ := crypt.NewCipher(r.encryptionKey, r.parts[r.pos].Salt) + location := r.parts[r.ranges[r.pos].PartNo].Location + start := r.ranges[r.pos].Start + end := r.ranges[r.pos].End + salt := r.parts[r.ranges[r.pos].PartNo].Salt + cipher, _ := crypt.NewCipher(r.encryptionKey, salt) return cipher.DecryptDataSeek(r.ctx, func(ctx context.Context, @@ -96,14 +102,10 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { var end int64 if underlyingLimit >= 0 { - end = min(r.parts[r.pos].Size-1, underlyingOffset+underlyingLimit-1) + end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1) } - return newTGReader(r.ctx, r.client, types.Part{ - Start: underlyingOffset, - End: end, - Location: r.parts[r.pos].Location, - }) - }, r.parts[r.pos].Start, r.parts[r.pos].End-r.parts[r.pos].Start+1) + return newTGReader(r.ctx, r.client, location, underlyingOffset, end) + }, start, end-start+1) } diff --git a/internal/reader/reader.go b/internal/reader/reader.go index 33a890e..e75972c 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -8,9 +8,37 @@ import ( "github.com/gotd/td/telegram" ) +func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range { + + partByteRanges := []types.Range{} + + startPart := startByte / partSize + + endPart := endByte / partSize + + startOffset := startByte % partSize + + for part := startPart; part <= endPart; part++ { + partStartByte := int64(0) + partEndByte := partSize - 1 + if part == startPart { + partStartByte = startOffset + } + if part == endPart { + partEndByte = int64(endByte % partSize) + } + partByteRanges = append(partByteRanges, types.Range{Start: partStartByte, End: partEndByte, PartNo: part}) + + startOffset = 0 + } + + return partByteRanges +} + type linearReader struct { ctx context.Context parts []types.Part + ranges []types.Range pos int client *telegram.Client reader io.ReadCloser @@ -21,22 +49,22 @@ type linearReader struct { func NewLinearReader(ctx context.Context, client *telegram.Client, parts []types.Part, - limit int64, + start, end int64, ) (reader io.ReadCloser, err error) { r := &linearReader{ ctx: ctx, parts: parts, client: client, - limit: limit, + limit: end - start + 1, + ranges: calculatePartByteRanges(start, end, parts[0].Size), } - reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos]) + r.reader, err = r.nextPart() if err != nil { return nil, err } - r.reader = reader return r, nil } @@ -62,14 +90,24 @@ func (r *linearReader) Read(p []byte) (n int, err error) { err = nil } r.pos++ - if r.pos < len(r.parts) { - r.reader, err = newTGReader(r.ctx, r.client, r.parts[r.pos]) + if r.pos < len(r.ranges) { + r.reader, err = r.nextPart() + } } r.err = err return } +func (r *linearReader) nextPart() (io.ReadCloser, error) { + + location := r.parts[r.ranges[r.pos].PartNo].Location + startByte := r.ranges[r.pos].Start + endByte := r.ranges[r.pos].End + + return newTGReader(r.ctx, r.client, location, startByte, endByte) +} + func (r *linearReader) Close() (err error) { if r.reader != nil { err = r.reader.Close() diff --git a/internal/reader/tgreader.go b/internal/reader/tgreader.go index e790e98..e1df44f 100644 --- a/internal/reader/tgreader.go +++ b/internal/reader/tgreader.go @@ -5,7 +5,6 @@ import ( "fmt" "io" - "github.com/divyam234/teldrive/pkg/types" "github.com/gotd/td/telegram" "github.com/gotd/td/tg" ) @@ -36,18 +35,20 @@ func calculateChunkSize(start, end int64) int64 { func newTGReader( ctx context.Context, client *telegram.Client, - part types.Part, + location *tg.InputDocumentFileLocation, + start int64, + end int64, ) (io.ReadCloser, error) { r := &tgReader{ ctx: ctx, - location: part.Location, + location: location, client: client, - start: part.Start, - end: part.End, - chunkSize: calculateChunkSize(part.Start, part.End), - limit: part.End - part.Start + 1, + start: start, + end: end, + chunkSize: calculateChunkSize(start, end), + limit: end - start + 1, } r.next = r.partStream() return r, nil diff --git a/internal/retry/retry.go b/internal/retry/retry.go index e5f5ff9..385cab8 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -17,6 +17,7 @@ var internalErrors = []string{ "RPC_CALL_FAIL", "RPC_MCGET_FAIL", "WORKER_BUSY_TOO_LONG_RETRY", + "memory limit exit", } type retry struct { diff --git a/pkg/services/common.go b/pkg/services/common.go index 022693c..6ecf4ec 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -163,64 +163,21 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu media := item.Media.(*tg.MessageMediaDocument) document := media.Document.(*tg.Document) location := document.AsInputDocumentFileLocation() - end := document.Size - 1 - if file.Encrypted { - end, _ = crypt.DecryptedSize(document.Size) - end -= 1 - } - parts = append(parts, types.Part{ + + part := types.Part{ Location: location, - End: end, Size: document.Size, Salt: file.Parts[i].Salt, - }) + } + if file.Encrypted { + part.DecryptedSize, _ = crypt.DecryptedSize(document.Size) + } + parts = append(parts, part) } cache.Set(key, &parts, 3600) return parts, nil } -func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part { - - chunkSize := parts[0].End + 1 - - numParts := int64(len(parts)) - - validParts := []types.Part{} - - firstChunk := max(startByte/chunkSize, 0) - - lastChunk := min(endByte/chunkSize, numParts) - - startInFirstChunk := startByte % chunkSize - - endInLastChunk := endByte % chunkSize - - if firstChunk == lastChunk { - part := parts[firstChunk] - part.Start = startInFirstChunk - part.End = endInLastChunk - validParts = append(validParts, part) - } else { - part := parts[firstChunk] - part.Start = startInFirstChunk - validParts = append(validParts, part) - // Add valid parts from any chunks in between. - for i := firstChunk + 1; i < lastChunk; i++ { - part := parts[i] - part.Start = 0 - validParts = append(validParts, part) - } - - // Add valid parts from the last chunk. - endPart := parts[lastChunk] - endPart.Start = 0 - endPart.End = endInLastChunk - validParts = append(validParts, endPart) - } - - return validParts -} - func GetChannelById(ctx context.Context, client *telegram.Client, channelId int64, userID string) (*tg.InputChannel, error) { channel := &tg.InputChannel{} diff --git a/pkg/services/file.go b/pkg/services/file.go index 12c7e92..da29c02 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -524,12 +524,20 @@ func (fs *FileService) GetFileStream(c *gin.Context) { return } - parts = rangedParts(parts, start, end) - if file.Encrypted { - lr, _ = reader.NewDecryptedReader(c, client.Tg, parts, contentLength, fs.cnf.Uploads.EncryptionKey) + lr, err = reader.NewDecryptedReader(c, client.Tg, parts, start, end, fs.cnf.Uploads.EncryptionKey) } else { - lr, _ = reader.NewLinearReader(c, client.Tg, parts, contentLength) + lr, err = reader.NewLinearReader(c, client.Tg, parts, start, end) + } + + if err != nil { + logger.Error("file stream", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if lr == nil { + http.Error(w, "failed to initialise reader", http.StatusInternalServerError) + return } io.CopyN(w, lr, contentLength) diff --git a/pkg/services/user.go b/pkg/services/user.go index 9881615..b26b758 100644 --- a/pkg/services/user.go +++ b/pkg/services/user.go @@ -234,10 +234,6 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI return err } - if err != nil { - return err - - } botInfoChannel := make(chan *types.BotInfo, len(botsTokens)) waitChan := make(chan struct{}, 6) @@ -261,9 +257,9 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI <-waitChan wg.Done() }() - if err == nil { - botInfoChannel <- info - } + + botInfoChannel <- info + }(token) } diff --git a/pkg/types/types.go b/pkg/types/types.go index 1f168b1..94e7cfc 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -12,11 +12,10 @@ type AppError struct { } type Part struct { - Location *tg.InputDocumentFileLocation - Start int64 - End int64 - Size int64 - Salt string + Location *tg.InputDocumentFileLocation + DecryptedSize int64 + Size int64 + Salt string } type JWTClaims struct { @@ -48,3 +47,9 @@ type BotInfo struct { AccessHash int64 Token string } + +type Range struct { + Start int64 + End int64 + PartNo int64 +}