teldrive/internal/reader/tg_multi_reader.go
package reader

import (
	"context"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/divyam234/teldrive/internal/cache"
	"github.com/divyam234/teldrive/internal/config"
	"github.com/divyam234/teldrive/internal/tgc"
	"github.com/gotd/td/tg"
	"golang.org/x/sync/errgroup"
)

var (
	ErrStreamAbandoned = errors.New("stream abandoned")
	ErrChunkTimeout    = errors.New("chunk fetch timed out")
)
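
// ChunkSource abstracts fetching raw byte ranges of a stored file part, so the
// reader does not care which Telegram client (dedicated or pooled) serves a
// given chunk.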
type ChunkSource interface {
	Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
	ChunkSize(start, end int64) int64
}
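
// chunkSource fetches chunks of a single file part from a Telegram channel,
// borrowing a client from the stream worker when concurrency is enabled and
// falling back to the dedicated client otherwise.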
type chunkSource struct {
	channelID   int64
	worker      *tgc.StreamWorker
	fileID      string
	partID      int64
	concurrency int
	client      *tgc.Client
	cache       cache.Cacher
}
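
// ChunkSize returns the download chunk size to use for the requested range.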
func (c *chunkSource) ChunkSize(start, end int64) int64 {
	return tgc.CalculateChunkSize(start, end)
}
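
// Chunk downloads a single chunk of the file part. With concurrency enabled a
// client is taken from the worker for the duration of the call and released
// when the call finishes; otherwise the dedicated client is used.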
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
	var (
		location *tg.InputDocumentFileLocation
		err      error
		client   *tgc.Client
	)

	client = c.client

	defer func() {
		if c.concurrency > 0 && client != nil {
			c.worker.Release(client)
		}
	}()

	if c.concurrency > 0 {
		client, _, _ = c.worker.Next(c.channelID)
	}

	location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)
	if err != nil {
		return nil, err
	}

	return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}
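
// tgMultiReader streams an arbitrary byte range [start, end] of a file part as
// an io.Reader. Chunks are fetched in batches by a background goroutine and
// handed to Read through bufferChan; the first and last chunks are trimmed
// with leftCut/rightCut so only the requested range is emitted.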
type tgMultiReader struct {
	ctx         context.Context
	cancel      context.CancelFunc
	offset      int64
	limit       int64
	chunkSize   int64
	bufferChan  chan *buffer
	cur         *buffer
	concurrency int
	leftCut     int64
	rightCut    int64
	totalParts  int
	currentPart int
	chunkSrc    ChunkSource
	timeout     time.Duration
}
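
// newTGMultiReader builds a reader for the inclusive byte range [start, end]
// and starts the background prefetcher. The start offset is aligned down to a
// chunk boundary, and the overshoot on both sides is recorded as leftCut and
// rightCut.
//
// Illustrative example (the actual chunk size comes from ChunkSize): with a
// 1 MiB chunk, start=100 and end=1048575 give offset=0, leftCut=100,
// rightCut=1048576 and totalParts=1, so the single fetched chunk is sliced to
// chunk[100:1048576].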
func newTGMultiReader(
	ctx context.Context,
	start int64,
	end int64,
	config *config.TGConfig,
	chunkSrc ChunkSource,
) (*tgMultiReader, error) {

	chunkSize := chunkSrc.ChunkSize(start, end)
	offset := start - (start % chunkSize)

	ctx, cancel := context.WithCancel(ctx)

	r := &tgMultiReader{
		ctx:         ctx,
		cancel:      cancel,
		limit:       end - start + 1,
		bufferChan:  make(chan *buffer, config.Stream.Buffers),
		concurrency: config.Stream.MultiThreads,
		leftCut:     start - offset,
		rightCut:    (end % chunkSize) + 1,
		totalParts:  int((end - offset + chunkSize) / chunkSize),
		offset:      offset,
		chunkSize:   chunkSize,
		chunkSrc:    chunkSrc,
		timeout:     config.Stream.ChunkTimeout,
	}

	go r.fillBufferConcurrently()

	return r, nil
}
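
// Close stops the background prefetcher by cancelling the reader's context.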
func (r *tgMultiReader) Close() error {
	r.cancel()
	return nil
}
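
// Read copies bytes from the current buffer into p, pulling the next buffer
// from bufferChan when the current one is exhausted. It returns io.EOF once
// the requested range has been fully delivered, and ErrStreamAbandoned if
// bufferChan was closed before that point.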
func (r *tgMultiReader) Read(p []byte) (int, error) {

	if r.limit <= 0 {
		return 0, io.EOF
	}

	if r.cur == nil || r.cur.isEmpty() {
		select {
		case cur, ok := <-r.bufferChan:
			if !ok {
				return 0, ErrStreamAbandoned
			}
			r.cur = cur
		case <-r.ctx.Done():
			return 0, r.ctx.Err()
		}
	}

	n := copy(p, r.cur.buffer())
	r.cur.increment(n)
	r.limit -= int64(n)

	if r.limit <= 0 {
		return n, io.EOF
	}

	return n, nil
}
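
// fillBufferConcurrently fetches chunk batches until every part has been
// produced or a batch fails, in which case the reader's context is cancelled
// and bufferChan is closed.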
func (r *tgMultiReader) fillBufferConcurrently() {
	defer close(r.bufferChan)

	for r.currentPart < r.totalParts {
		if err := r.fillBatch(); err != nil {
			r.cancel()
			return
		}
	}
}
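
// fillBatch fetches up to concurrency chunks in parallel, trims the first and
// last parts of the overall range to the requested byte boundaries, and then
// pushes the resulting buffers to bufferChan in order.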
func (r *tgMultiReader) fillBatch() error {

	g, ctx := errgroup.WithContext(r.ctx)
	g.SetLimit(r.concurrency)

	buffers := make([]*buffer, r.concurrency)

	for i := 0; i < r.concurrency && r.currentPart+i < r.totalParts; i++ {
		i := i // capture the loop variable for the goroutine (needed on Go < 1.22)
		g.Go(func() error {
			chunkCtx, cancel := context.WithTimeout(ctx, r.timeout)
			defer cancel()

			chunk, err := r.fetchChunkWithTimeout(chunkCtx, int64(i))
			if err != nil {
				if errors.Is(err, context.DeadlineExceeded) {
					return fmt.Errorf("chunk %d: %w", r.currentPart+i, ErrChunkTimeout)
				}
				return fmt.Errorf("chunk %d: %w", r.currentPart+i, err)
			}

			// Trim the first and last chunks so only bytes inside the
			// requested range are kept.
			if r.totalParts == 1 {
				chunk = chunk[r.leftCut:r.rightCut]
			} else if r.currentPart+i == 0 {
				chunk = chunk[r.leftCut:]
			} else if r.currentPart+i+1 == r.totalParts {
				chunk = chunk[:r.rightCut]
			}

			buffers[i] = &buffer{buf: chunk}
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}

	// Deliver the fetched buffers in order; trailing nil slots mean fewer
	// parts than concurrency remained in this batch.
	for _, buf := range buffers {
		if buf == nil {
			break
		}
		select {
		case r.bufferChan <- buf:
		case <-r.ctx.Done():
			return r.ctx.Err()
		}
	}

	r.currentPart += r.concurrency
	r.offset += r.chunkSize * int64(r.concurrency)

	return nil
}
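
// fetchChunkWithTimeout runs the chunk download in its own goroutine so the
// caller can return as soon as ctx (which carries the per-chunk timeout) is
// done, even if the underlying call has not yet completed.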
func (r *tgMultiReader) fetchChunkWithTimeout(ctx context.Context, i int64) ([]byte, error) {
	chunkChan := make(chan []byte, 1)
	errChan := make(chan error, 1)

	go func() {
		chunk, err := r.chunkSrc.Chunk(ctx, r.offset+i*r.chunkSize, r.chunkSize)
		if err != nil {
			errChan <- err
		} else {
			chunkChan <- chunk
		}
	}()

	select {
	case chunk := <-chunkChan:
		return chunk, nil
	case err := <-errChan:
		return nil, err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
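
// buffer is a chunk of downloaded bytes plus a read offset, consumed
// incrementally by Read.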
type buffer struct {
	buf    []byte
	offset int
}

func (b *buffer) isEmpty() bool {
	return b == nil || len(b.buf)-b.offset <= 0
}

func (b *buffer) buffer() []byte {
	return b.buf[b.offset:]
}

func (b *buffer) increment(n int) {
	b.offset += n
}