teldrive/internal/reader/tg_multi_reader.go
2024-06-29 15:07:54 +05:30

281 lines
5 KiB
Go

package reader
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/gotd/td/tg"
"golang.org/x/sync/errgroup"
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
type ChunkSource interface {
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
ChunkSize(start, end int64) int64
}
type chunkSource struct {
channelId int64
worker *tgc.StreamWorker
fileId string
partId int64
concurrency int
client *tgc.Client
}
func (c *chunkSource) ChunkSize(start, end int64) int64 {
return tgc.CalculateChunkSize(start, end)
}
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
var (
location *tg.InputDocumentFileLocation
err error
client *tgc.Client
)
client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId)
}
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
if err != nil {
return nil, err
}
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}
type tgMultiReader struct {
ctx context.Context
offset int64
limit int64
chunkSize int64
bufferChan chan *buffer
done chan struct{}
cur *buffer
err chan error
mu sync.Mutex
concurrency int
leftCut int64
rightCut int64
totalParts int
currentPart int
closed bool
timeout time.Duration
chunkSrc ChunkSource
}
func newTGMultiReader(
ctx context.Context,
start int64,
end int64,
config *config.TGConfig,
chunkSrc ChunkSource,
) (*tgMultiReader, error) {
chunkSize := chunkSrc.ChunkSize(start, end)
offset := start - (start % chunkSize)
r := &tgMultiReader{
ctx: ctx,
limit: end - start + 1,
bufferChan: make(chan *buffer, config.Stream.Buffers),
concurrency: config.Stream.MultiThreads,
leftCut: start - offset,
rightCut: (end % chunkSize) + 1,
totalParts: int((end - offset + chunkSize) / chunkSize),
offset: offset,
chunkSize: chunkSize,
chunkSrc: chunkSrc,
timeout: config.Stream.ChunkTimeout,
done: make(chan struct{}, 1),
err: make(chan error, 1),
}
go r.fillBufferConcurrently()
return r, nil
}
func (r *tgMultiReader) Close() error {
close(r.done)
close(r.bufferChan)
r.closed = true
for b := range r.bufferChan {
if b != nil {
b = nil
}
}
if r.cur != nil {
r.cur = nil
}
close(r.err)
return nil
}
func (r *tgMultiReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.limit <= 0 {
return 0, io.EOF
}
if r.cur.isEmpty() {
if r.cur != nil {
r.cur = nil
}
select {
case cur, ok := <-r.bufferChan:
if !ok && r.limit > 0 {
return 0, ErrorStreamAbandoned
}
r.cur = cur
case err := <-r.err:
return 0, fmt.Errorf("error reading chunk: %w", err)
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
}
n := copy(p, r.cur.buffer())
r.cur.increment(n)
r.limit -= int64(n)
if r.limit <= 0 {
return n, io.EOF
}
return n, nil
}
func (r *tgMultiReader) fillBufferConcurrently() error {
var mapMu sync.Mutex
bufferMap := make(map[int]*buffer)
defer func() {
for i := range bufferMap {
delete(bufferMap, i)
}
}()
cb := func(ctx context.Context, i int) func() error {
return func() error {
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
if err != nil {
return err
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart+i+1 == 1 {
chunk = chunk[r.leftCut:]
} else if r.currentPart+i+1 == r.totalParts {
chunk = chunk[:r.rightCut]
}
buf := &buffer{buf: chunk}
mapMu.Lock()
bufferMap[i] = buf
mapMu.Unlock()
return nil
}
}
for {
g := errgroup.Group{}
g.SetLimit(r.concurrency)
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
g.Go(cb(r.ctx, i))
}
}
done := make(chan error, 1)
go func() {
done <- g.Wait()
}()
select {
case err := <-done:
if err != nil {
r.err <- err
return nil
} else {
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
select {
case <-r.done:
return nil
case r.bufferChan <- bufferMap[i]:
}
}
}
r.currentPart += r.concurrency
r.offset += r.chunkSize * int64(r.concurrency)
for i := range bufferMap {
delete(bufferMap, i)
}
if r.currentPart >= r.totalParts {
return nil
}
}
case <-time.After(r.timeout):
return nil
case <-r.done:
return nil
case <-r.ctx.Done():
return r.ctx.Err()
}
}
}
type buffer struct {
buf []byte
offset int
}
func (b *buffer) isEmpty() bool {
if b == nil {
return true
}
if len(b.buf)-b.offset <= 0 {
return true
}
return false
}
func (b *buffer) buffer() []byte {
return b.buf[b.offset:]
}
func (b *buffer) increment(n int) {
b.offset += n
}