mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-24 16:12:18 +08:00
252 lines
5 KiB
Go
252 lines
5 KiB
Go
package reader
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/divyam234/teldrive/internal/cache"
|
|
"github.com/divyam234/teldrive/internal/config"
|
|
"github.com/divyam234/teldrive/internal/tgc"
|
|
"github.com/gotd/td/tg"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// Sentinel errors exposed by the reader package.
var (
	// ErrStreamAbandoned is returned by tgMultiReader.Read when the buffer
	// channel has been closed while bytes are still owed to the caller.
	ErrStreamAbandoned = errors.New("stream abandoned")

	// ErrChunkTimeout is the error a chunk fetch is wrapped with when its
	// per-chunk deadline expires.
	ErrChunkTimeout = errors.New("chunk fetch timed out")
)
|
|
|
|
// ChunkSource abstracts where the multi-reader gets its data from.
type ChunkSource interface {
	// Chunk fetches the bytes for one chunk. The reader calls it with a
	// chunk-aligned offset and the chunk size as limit.
	Chunk(ctx context.Context, offset, limit int64) ([]byte, error)

	// ChunkSize reports the chunk size to use for the inclusive byte range
	// [start, end]; the reader aligns its requests to this size.
	ChunkSize(start, end int64) int64
}
|
|
|
|
type chunkSource struct {
|
|
channelID int64
|
|
worker *tgc.StreamWorker
|
|
fileID string
|
|
partID int64
|
|
concurrency int
|
|
client *tgc.Client
|
|
cache cache.Cacher
|
|
}
|
|
|
|
func (c *chunkSource) ChunkSize(start, end int64) int64 {
|
|
return tgc.CalculateChunkSize(start, end)
|
|
}
|
|
|
|
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
|
var (
|
|
location *tg.InputDocumentFileLocation
|
|
err error
|
|
client *tgc.Client
|
|
)
|
|
|
|
client = c.client
|
|
|
|
defer func() {
|
|
if c.concurrency > 0 && client != nil {
|
|
defer c.worker.Release(client)
|
|
}
|
|
}()
|
|
|
|
if c.concurrency > 0 {
|
|
client, _, _ = c.worker.Next(c.channelID)
|
|
}
|
|
location, err = tgc.GetLocation(ctx, client, c.cache, c.fileID, c.channelID, c.partID)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
|
|
|
|
}
|
|
|
|
type tgMultiReader struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
offset int64
|
|
limit int64
|
|
chunkSize int64
|
|
bufferChan chan *buffer
|
|
cur *buffer
|
|
concurrency int
|
|
leftCut int64
|
|
rightCut int64
|
|
totalParts int
|
|
currentPart int
|
|
chunkSrc ChunkSource
|
|
timeout time.Duration
|
|
}
|
|
|
|
func newTGMultiReader(
|
|
ctx context.Context,
|
|
start int64,
|
|
end int64,
|
|
config *config.TGConfig,
|
|
chunkSrc ChunkSource,
|
|
) (*tgMultiReader, error) {
|
|
chunkSize := chunkSrc.ChunkSize(start, end)
|
|
offset := start - (start % chunkSize)
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
r := &tgMultiReader{
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
limit: end - start + 1,
|
|
bufferChan: make(chan *buffer, config.Stream.Buffers),
|
|
concurrency: config.Stream.MultiThreads,
|
|
leftCut: start - offset,
|
|
rightCut: (end % chunkSize) + 1,
|
|
totalParts: int((end - offset + chunkSize) / chunkSize),
|
|
offset: offset,
|
|
chunkSize: chunkSize,
|
|
chunkSrc: chunkSrc,
|
|
timeout: config.Stream.ChunkTimeout,
|
|
}
|
|
|
|
go r.fillBufferConcurrently()
|
|
return r, nil
|
|
}
|
|
|
|
func (r *tgMultiReader) Close() error {
|
|
r.cancel()
|
|
return nil
|
|
}
|
|
|
|
func (r *tgMultiReader) Read(p []byte) (int, error) {
|
|
if r.limit <= 0 {
|
|
return 0, io.EOF
|
|
}
|
|
|
|
if r.cur == nil || r.cur.isEmpty() {
|
|
select {
|
|
case cur, ok := <-r.bufferChan:
|
|
if !ok {
|
|
return 0, ErrStreamAbandoned
|
|
}
|
|
r.cur = cur
|
|
case <-r.ctx.Done():
|
|
return 0, r.ctx.Err()
|
|
}
|
|
}
|
|
|
|
n := copy(p, r.cur.buffer())
|
|
r.cur.increment(n)
|
|
r.limit -= int64(n)
|
|
|
|
if r.limit <= 0 {
|
|
return n, io.EOF
|
|
}
|
|
|
|
return n, nil
|
|
}
|
|
|
|
func (r *tgMultiReader) fillBufferConcurrently() {
|
|
defer close(r.bufferChan)
|
|
|
|
for r.currentPart < r.totalParts {
|
|
if err := r.fillBatch(); err != nil {
|
|
r.cancel()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *tgMultiReader) fillBatch() error {
|
|
g, ctx := errgroup.WithContext(r.ctx)
|
|
g.SetLimit(r.concurrency)
|
|
|
|
buffers := make([]*buffer, r.concurrency)
|
|
|
|
for i := 0; i < r.concurrency && r.currentPart+i < r.totalParts; i++ {
|
|
g.Go(func() error {
|
|
chunkCtx, cancel := context.WithTimeout(ctx, r.timeout)
|
|
defer cancel()
|
|
|
|
chunk, err := r.fetchChunkWithTimeout(chunkCtx, int64(i))
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return fmt.Errorf("chunk %d: %w", r.currentPart+i, ErrChunkTimeout)
|
|
}
|
|
return fmt.Errorf("chunk %d: %w", r.currentPart+i, err)
|
|
}
|
|
|
|
if r.totalParts == 1 {
|
|
chunk = chunk[r.leftCut:r.rightCut]
|
|
} else if r.currentPart+i == 0 {
|
|
chunk = chunk[r.leftCut:]
|
|
} else if r.currentPart+i+1 == r.totalParts {
|
|
chunk = chunk[:r.rightCut]
|
|
}
|
|
|
|
buffers[i] = &buffer{buf: chunk}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if err := g.Wait(); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, buf := range buffers {
|
|
if buf == nil {
|
|
break
|
|
}
|
|
select {
|
|
case r.bufferChan <- buf:
|
|
case <-r.ctx.Done():
|
|
return r.ctx.Err()
|
|
}
|
|
}
|
|
|
|
r.currentPart += r.concurrency
|
|
r.offset += r.chunkSize * int64(r.concurrency)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *tgMultiReader) fetchChunkWithTimeout(ctx context.Context, i int64) ([]byte, error) {
|
|
chunkChan := make(chan []byte, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+i*r.chunkSize, r.chunkSize)
|
|
if err != nil {
|
|
errChan <- err
|
|
} else {
|
|
chunkChan <- chunk
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case chunk := <-chunkChan:
|
|
return chunk, nil
|
|
case err := <-errChan:
|
|
return nil, err
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
// buffer is a byte slice paired with a read cursor.
type buffer struct {
	// buf holds the chunk data.
	buf []byte
	// offset is the number of bytes of buf already consumed.
	offset int
}

// isEmpty reports whether no unread bytes remain. A nil buffer is empty.
func (b *buffer) isEmpty() bool {
	if b == nil {
		return true
	}
	return b.offset >= len(b.buf)
}

// buffer returns the unread tail of the data.
func (b *buffer) buffer() []byte {
	unread := b.buf[b.offset:]
	return unread
}

// increment advances the read cursor by n bytes.
func (b *buffer) increment(n int) {
	b.offset = b.offset + n
}
|