refactor: reader

divyam234 2024-06-29 15:07:54 +05:30
parent 04755c76e6
commit c5cd24bbb3
5 changed files with 323 additions and 269 deletions

View file

@@ -112,7 +112,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
 		chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
 			fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
 			client: r.client, concurrency: r.concurrency}
-		return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
+		if r.concurrency < 2 {
+			return newTGReader(r.ctx, start, end, chunkSrc)
+		}
+		return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
 	}, start, end-start+1)

View file

@@ -117,7 +117,10 @@ func (r *linearReader) nextPart() (io.ReadCloser, error) {
 	chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
 		fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
 		client: r.client, concurrency: r.concurrency}
-	return newTGReader(r.ctx, start, end, r.config, chunkSrc)
+	if r.concurrency < 2 {
+		return newTGReader(r.ctx, start, end, chunkSrc)
+	}
+	return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
 }
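
Both nextPart call sites now share one dispatch rule: with fewer than two configured stream threads the range is served by the simplified sequential tgReader, otherwise by the prefetching tgMultiReader introduced below. Distilled into a single place (newReaderFor is a hypothetical helper for illustration, not part of this commit; types and imports are those of the files above):

// Hypothetical helper showing the dispatch both call sites now perform inline.
func newReaderFor(ctx context.Context, start, end int64, cfg *config.TGConfig, src ChunkSource) (io.ReadCloser, error) {
	if cfg.Stream.MultiThreads < 2 {
		// Sequential path: one chunk fetched at a time, on demand, no goroutines.
		return newTGReader(ctx, start, end, src)
	}
	// Concurrent path: worker goroutines prefetch chunks into a buffered channel.
	return newTGMultiReader(ctx, start, end, cfg, src)
}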

View file

@@ -0,0 +1,281 @@
package reader

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/divyam234/teldrive/internal/config"
	"github.com/divyam234/teldrive/internal/tgc"
	"github.com/gotd/td/tg"
	"golang.org/x/sync/errgroup"
)

var ErrorStreamAbandoned = errors.New("stream abandoned")

type ChunkSource interface {
	Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
	ChunkSize(start, end int64) int64
}

type chunkSource struct {
	channelId   int64
	worker      *tgc.StreamWorker
	fileId      string
	partId      int64
	concurrency int
	client      *tgc.Client
}

func (c *chunkSource) ChunkSize(start, end int64) int64 {
	return tgc.CalculateChunkSize(start, end)
}

func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
	var (
		location *tg.InputDocumentFileLocation
		err      error
		client   *tgc.Client
	)
	client = c.client
	defer func() {
		if c.concurrency > 0 && client != nil {
			defer c.worker.Release(client)
		}
	}()
	if c.concurrency > 0 {
		client, _, _ = c.worker.Next(c.channelId)
	}
	location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
	if err != nil {
		return nil, err
	}
	return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}

type tgMultiReader struct {
	ctx         context.Context
	offset      int64
	limit       int64
	chunkSize   int64
	bufferChan  chan *buffer
	done        chan struct{}
	cur         *buffer
	err         chan error
	mu          sync.Mutex
	concurrency int
	leftCut     int64
	rightCut    int64
	totalParts  int
	currentPart int
	closed      bool
	timeout     time.Duration
	chunkSrc    ChunkSource
}

func newTGMultiReader(
	ctx context.Context,
	start int64,
	end int64,
	config *config.TGConfig,
	chunkSrc ChunkSource,
) (*tgMultiReader, error) {
	chunkSize := chunkSrc.ChunkSize(start, end)
	offset := start - (start % chunkSize)
	r := &tgMultiReader{
		ctx:         ctx,
		limit:       end - start + 1,
		bufferChan:  make(chan *buffer, config.Stream.Buffers),
		concurrency: config.Stream.MultiThreads,
		leftCut:     start - offset,
		rightCut:    (end % chunkSize) + 1,
		totalParts:  int((end - offset + chunkSize) / chunkSize),
		offset:      offset,
		chunkSize:   chunkSize,
		chunkSrc:    chunkSrc,
		timeout:     config.Stream.ChunkTimeout,
		done:        make(chan struct{}, 1),
		err:         make(chan error, 1),
	}
	go r.fillBufferConcurrently()
	return r, nil
}

func (r *tgMultiReader) Close() error {
	close(r.done)
	close(r.bufferChan)
	r.closed = true
	for b := range r.bufferChan {
		if b != nil {
			b = nil
		}
	}
	if r.cur != nil {
		r.cur = nil
	}
	close(r.err)
	return nil
}

func (r *tgMultiReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.limit <= 0 {
		return 0, io.EOF
	}
	if r.cur.isEmpty() {
		if r.cur != nil {
			r.cur = nil
		}
		select {
		case cur, ok := <-r.bufferChan:
			if !ok && r.limit > 0 {
				return 0, ErrorStreamAbandoned
			}
			r.cur = cur
		case err := <-r.err:
			return 0, fmt.Errorf("error reading chunk: %w", err)
		case <-r.ctx.Done():
			return 0, r.ctx.Err()
		}
	}
	n := copy(p, r.cur.buffer())
	r.cur.increment(n)
	r.limit -= int64(n)
	if r.limit <= 0 {
		return n, io.EOF
	}
	return n, nil
}

func (r *tgMultiReader) fillBufferConcurrently() error {
	var mapMu sync.Mutex
	bufferMap := make(map[int]*buffer)
	defer func() {
		for i := range bufferMap {
			delete(bufferMap, i)
		}
	}()
	cb := func(ctx context.Context, i int) func() error {
		return func() error {
			chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
			if err != nil {
				return err
			}
			if r.totalParts == 1 {
				chunk = chunk[r.leftCut:r.rightCut]
			} else if r.currentPart+i+1 == 1 {
				chunk = chunk[r.leftCut:]
			} else if r.currentPart+i+1 == r.totalParts {
				chunk = chunk[:r.rightCut]
			}
			buf := &buffer{buf: chunk}
			mapMu.Lock()
			bufferMap[i] = buf
			mapMu.Unlock()
			return nil
		}
	}
	for {
		g := errgroup.Group{}
		g.SetLimit(r.concurrency)
		for i := range r.concurrency {
			if r.currentPart+i+1 <= r.totalParts {
				g.Go(cb(r.ctx, i))
			}
		}
		done := make(chan error, 1)
		go func() {
			done <- g.Wait()
		}()
		select {
		case err := <-done:
			if err != nil {
				r.err <- err
				return nil
			} else {
				for i := range r.concurrency {
					if r.currentPart+i+1 <= r.totalParts {
						select {
						case <-r.done:
							return nil
						case r.bufferChan <- bufferMap[i]:
						}
					}
				}
				r.currentPart += r.concurrency
				r.offset += r.chunkSize * int64(r.concurrency)
				for i := range bufferMap {
					delete(bufferMap, i)
				}
				if r.currentPart >= r.totalParts {
					return nil
				}
			}
		case <-time.After(r.timeout):
			return nil
		case <-r.done:
			return nil
		case <-r.ctx.Done():
			return r.ctx.Err()
		}
	}
}

type buffer struct {
	buf    []byte
	offset int
}

func (b *buffer) isEmpty() bool {
	if b == nil {
		return true
	}
	if len(b.buf)-b.offset <= 0 {
		return true
	}
	return false
}

func (b *buffer) buffer() []byte {
	return b.buf[b.offset:]
}

func (b *buffer) increment(n int) {
	b.offset += n
}
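
For orientation, a hedged usage sketch of the new reader, written as if it sat in the reader package so it can reach the unexported constructor. memSource, the fixed 16-byte chunk size, and the zero-value construction of config.TGConfig are illustrative assumptions, not part of this commit:

package reader

import (
	"context"
	"fmt"
	"io"
	"time"

	"github.com/divyam234/teldrive/internal/config"
)

// memSource is a stand-in ChunkSource that serves chunks from a byte slice.
type memSource struct{ data []byte }

// A fixed chunk size keeps the arithmetic easy to follow; the real
// implementation delegates to tgc.CalculateChunkSize.
func (m *memSource) ChunkSize(start, end int64) int64 { return 16 }

func (m *memSource) Chunk(_ context.Context, offset, limit int64) ([]byte, error) {
	stop := min(offset+limit, int64(len(m.data))) // min builtin needs Go 1.21+
	return m.data[offset:stop], nil
}

func ExampleMultiReader() {
	data := make([]byte, 100)
	for i := range data {
		data[i] = byte(i)
	}
	// Field names follow the references above; zero-value initialization
	// of config.TGConfig is an assumption.
	cfg := &config.TGConfig{}
	cfg.Stream.Buffers = 4                     // chunks buffered ahead of the consumer
	cfg.Stream.MultiThreads = 2                // chunk fetches issued per batch
	cfg.Stream.ChunkTimeout = 30 * time.Second // producer gives up after this idle period
	// Ask for bytes 10 through 25; the reader trims chunk padding internally.
	r, err := newTGMultiReader(context.Background(), 10, 25, cfg, &memSource{data: data})
	if err != nil {
		panic(err)
	}
	defer r.Close()
	out, _ := io.ReadAll(r)
	fmt.Println(len(out)) // 16, i.e. end-start+1
}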

View file

@@ -59,7 +59,7 @@ func (suite *TestSuite) TestFullRead() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSource{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	test_data, err := io.ReadAll(reader)
 	assert.Equal(suite.T(), nil, err)
@@ -73,7 +73,7 @@ func (suite *TestSuite) TestPartialRead() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSource{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	test_data, err := io.ReadAll(reader)
 	assert.NoError(suite.T(), err)
@@ -87,7 +87,7 @@ func (suite *TestSuite) TestTimeout() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSourceTimeout{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	test_data, err := io.ReadAll(reader)
 	assert.Greater(suite.T(), len(test_data), 0)
@@ -101,7 +101,7 @@ func (suite *TestSuite) TestClose() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSource{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	_, err = io.ReadAll(reader)
 	assert.NoError(suite.T(), err)
@@ -115,7 +115,7 @@ func (suite *TestSuite) TestCancellation() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSource{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	cancel()
 	_, err = io.ReadAll(reader)
@@ -131,7 +131,7 @@ func (suite *TestSuite) TestCancellationWithTimeout() {
 	data := make([]byte, 100)
 	rand.Read(data)
 	chunkSrc := &testChunkSourceTimeout{buffer: data}
-	reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
+	reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
 	assert.NoError(suite.T(), err)
 	_, err = io.ReadAll(reader)
 	assert.Equal(suite.T(), err, context.DeadlineExceeded)

View file

@@ -2,94 +2,30 @@ package reader
 
 import (
 	"context"
-	"errors"
-	"fmt"
 	"io"
-	"sync"
-	"time"
-
-	"github.com/divyam234/teldrive/internal/config"
-	"github.com/divyam234/teldrive/internal/tgc"
-	"github.com/gotd/td/tg"
-	"golang.org/x/sync/errgroup"
 )
 
-var ErrorStreamAbandoned = errors.New("stream abandoned")
-
-type ChunkSource interface {
-	Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
-	ChunkSize(start, end int64) int64
-}
-
-type chunkSource struct {
-	channelId   int64
-	worker      *tgc.StreamWorker
-	fileId      string
-	partId      int64
-	concurrency int
-	client      *tgc.Client
-}
-
-func (c *chunkSource) ChunkSize(start, end int64) int64 {
-	return tgc.CalculateChunkSize(start, end)
-}
-
-func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
-	var (
-		location *tg.InputDocumentFileLocation
-		err      error
-		client   *tgc.Client
-	)
-	client = c.client
-	defer func() {
-		if c.concurrency > 0 && client != nil {
-			defer c.worker.Release(client)
-		}
-	}()
-	if c.concurrency > 0 {
-		client, _, _ = c.worker.Next(c.channelId)
-	}
-	location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
-	if err != nil {
-		return nil, err
-	}
-	return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
-}
-
 type tgReader struct {
 	ctx         context.Context
+	cur         *buffer
 	offset      int64
 	limit       int64
 	chunkSize   int64
-	bufferChan  chan *buffer
-	done        chan struct{}
-	cur         *buffer
-	err         chan error
-	mu          sync.Mutex
-	concurrency int
 	leftCut     int64
 	rightCut    int64
 	totalParts  int
 	currentPart int
-	closed      bool
-	timeout     time.Duration
 	chunkSrc    ChunkSource
+	err         error
 }
 
 func newTGReader(
 	ctx context.Context,
 	start int64,
 	end int64,
-	config *config.TGConfig,
 	chunkSrc ChunkSource,
-) (*tgReader, error) {
+) (io.ReadCloser, error) {
 	chunkSize := chunkSrc.ChunkSize(start, end)
@@ -97,70 +33,28 @@ func newTGReader(
 	r := &tgReader{
 		ctx:         ctx,
-		limit:       end - start + 1,
-		bufferChan:  make(chan *buffer, config.Stream.Buffers),
-		concurrency: config.Stream.MultiThreads,
 		leftCut:     start - offset,
 		rightCut:    (end % chunkSize) + 1,
 		totalParts:  int((end - offset + chunkSize) / chunkSize),
 		offset:      offset,
+		limit:       end - start + 1,
 		chunkSize:   chunkSize,
 		chunkSrc:    chunkSrc,
-		timeout:     config.Stream.ChunkTimeout,
-		done:        make(chan struct{}, 1),
-		err:         make(chan error, 1),
+		currentPart: 1,
 	}
-	if r.concurrency == 0 {
-		r.currentPart = 1
-		go r.fillBufferSequentially()
-	} else {
-		go r.fillBufferConcurrently()
-	}
 	return r, nil
 }
 
-func (r *tgReader) Close() error {
-	close(r.done)
-	close(r.bufferChan)
-	r.closed = true
-	for b := range r.bufferChan {
-		if b != nil {
-			b = nil
-		}
-	}
-	if r.cur != nil {
-		r.cur = nil
-	}
-	close(r.err)
-	return nil
-}
-
 func (r *tgReader) Read(p []byte) (int, error) {
-	r.mu.Lock()
-	defer r.mu.Unlock()
 	if r.limit <= 0 {
 		return 0, io.EOF
 	}
 	if r.cur.isEmpty() {
-		if r.cur != nil {
-			r.cur = nil
-		}
-		select {
-		case cur, ok := <-r.bufferChan:
-			if !ok && r.limit > 0 {
-				return 0, ErrorStreamAbandoned
-			}
-			r.cur = cur
-		case err := <-r.err:
-			return 0, fmt.Errorf("error reading chunk: %w", err)
-		case <-r.ctx.Done():
-			return 0, r.ctx.Err()
+		r.cur, r.err = r.next()
+		if r.err != nil {
+			return 0, r.err
 		}
 	}
@@ -175,157 +69,29 @@ func (r *tgReader) Read(p []byte) (int, error) {
 	return n, nil
 }
 
-func (r *tgReader) fillBufferConcurrently() error {
-	var mapMu sync.Mutex
-	bufferMap := make(map[int]*buffer)
-	defer func() {
-		for i := range bufferMap {
-			delete(bufferMap, i)
-		}
-	}()
-	cb := func(ctx context.Context, i int) func() error {
-		return func() error {
-			chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
-			if err != nil {
-				return err
-			}
-			if r.totalParts == 1 {
-				chunk = chunk[r.leftCut:r.rightCut]
-			} else if r.currentPart+i+1 == 1 {
-				chunk = chunk[r.leftCut:]
-			} else if r.currentPart+i+1 == r.totalParts {
-				chunk = chunk[:r.rightCut]
-			}
-			buf := &buffer{buf: chunk}
-			mapMu.Lock()
-			bufferMap[i] = buf
-			mapMu.Unlock()
-			return nil
-		}
-	}
-	for {
-		g := errgroup.Group{}
-		g.SetLimit(r.concurrency)
-		for i := range r.concurrency {
-			if r.currentPart+i+1 <= r.totalParts {
-				g.Go(cb(r.ctx, i))
-			}
-		}
-		done := make(chan error, 1)
-		go func() {
-			done <- g.Wait()
-		}()
-		select {
-		case err := <-done:
-			if err != nil {
-				r.err <- err
-				return nil
-			} else {
-				for i := range r.concurrency {
-					if r.currentPart+i+1 <= r.totalParts {
-						select {
-						case <-r.done:
-							return nil
-						case r.bufferChan <- bufferMap[i]:
-						}
-					}
-				}
-				r.currentPart += r.concurrency
-				r.offset += r.chunkSize * int64(r.concurrency)
-				for i := range bufferMap {
-					delete(bufferMap, i)
-				}
-				if r.currentPart >= r.totalParts {
-					return nil
-				}
-			}
-		case <-time.After(r.timeout):
-			return nil
-		case <-r.done:
-			return nil
-		case <-r.ctx.Done():
-			return r.ctx.Err()
-		}
-	}
+func (*tgReader) Close() error {
+	return nil
 }
 
-func (r *tgReader) fillBufferSequentially() error {
-	fetchChunk := func(ctx context.Context) (*buffer, error) {
-		chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
-		if err != nil {
-			return nil, err
-		}
-		if r.totalParts == 1 {
-			chunk = chunk[r.leftCut:r.rightCut]
-		} else if r.currentPart == 1 {
-			chunk = chunk[r.leftCut:]
-		} else if r.currentPart == r.totalParts {
-			chunk = chunk[:r.rightCut]
-		}
-		return &buffer{buf: chunk}, nil
-	}
-	for {
-		select {
-		case <-r.done:
-			return nil
-		case <-r.ctx.Done():
-			return r.ctx.Err()
-		case <-time.After(r.timeout):
-			return nil
-		default:
-			buf, err := fetchChunk(r.ctx)
-			if err != nil {
-				r.err <- err
-				return nil
-			}
-			if r.closed {
-				return nil
-			}
-			r.bufferChan <- buf
-			r.currentPart++
-			r.offset += r.chunkSize
-			if r.currentPart > r.totalParts {
-				return nil
-			}
-		}
-	}
-}
-
-type buffer struct {
-	buf    []byte
-	offset int
-}
-
-func (b *buffer) isEmpty() bool {
-	if b == nil {
-		return true
-	}
-	if len(b.buf)-b.offset <= 0 {
-		return true
-	}
-	return false
-}
-
-func (b *buffer) buffer() []byte {
-	return b.buf[b.offset:]
-}
-
-func (b *buffer) increment(n int) {
-	b.offset += n
+func (r *tgReader) next() (*buffer, error) {
+	if r.currentPart > r.totalParts {
+		return nil, io.EOF
+	}
+	chunk, err := r.chunkSrc.Chunk(r.ctx, r.offset, r.chunkSize)
+	if err != nil {
+		return nil, err
+	}
+	if r.totalParts == 1 {
+		chunk = chunk[r.leftCut:r.rightCut]
+	} else if r.currentPart == 1 {
+		chunk = chunk[r.leftCut:]
+	} else if r.currentPart == r.totalParts {
+		chunk = chunk[:r.rightCut]
+	}
+	r.currentPart++
+	r.offset += r.chunkSize
	return &buffer{buf: chunk}, nil
 }
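
The boundary arithmetic that next() relies on (and that tgMultiReader's constructor shares) is easiest to check with concrete numbers. A worked sketch, assuming a 16-byte chunk size; the real size comes from tgc.CalculateChunkSize:

// Request bytes 10 through 25 of a file, with an assumed chunkSize of 16.
start, end, chunkSize := int64(10), int64(25), int64(16)
offset := start - (start % chunkSize)                     // 0: fetches begin on a chunk boundary
leftCut := start - offset                                 // 10: bytes dropped from the first chunk
rightCut := (end % chunkSize) + 1                         // 10: bytes kept from the last chunk
totalParts := int((end - offset + chunkSize) / chunkSize) // 2: chunks covering the range
// Part 1 yields chunk[10:16] (6 bytes); part 2 yields chunk[:10] (10 bytes).
// Together that is 16 bytes, exactly end-start+1, so no chunk padding ever
// reaches the caller.
_, _, _ = leftCut, rightCut, totalParts // silence unused-variable errors in a snippet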