From c5cd24bbb39a6a6dedc754133e64b2f8c6cddcfb Mon Sep 17 00:00:00 2001 From: divyam234 <47589864+divyam234@users.noreply.github.com> Date: Sat, 29 Jun 2024 15:07:54 +0530 Subject: [PATCH] refactor: reader --- internal/reader/decrypted_reader.go | 6 +- internal/reader/reader.go | 5 +- internal/reader/tg_multi_reader.go | 281 +++++++++++++++++ ...reader_test.go => tg_multi_reader_test.go} | 12 +- internal/reader/tg_reader.go | 288 ++---------------- 5 files changed, 323 insertions(+), 269 deletions(-) create mode 100644 internal/reader/tg_multi_reader.go rename internal/reader/{tg_reader_test.go => tg_multi_reader_test.go} (88%) diff --git a/internal/reader/decrypted_reader.go b/internal/reader/decrypted_reader.go index effae79..e9fd691 100644 --- a/internal/reader/decrypted_reader.go +++ b/internal/reader/decrypted_reader.go @@ -112,7 +112,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, client: r.client, concurrency: r.concurrency} - return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc) + + if r.concurrency < 2 { + return newTGReader(r.ctx, start, end, chunkSrc) + } + return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc) }, start, end-start+1) diff --git a/internal/reader/reader.go b/internal/reader/reader.go index 68c1cc6..fb896ca 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -117,7 +117,10 @@ func (r *linearReader) nextPart() (io.ReadCloser, error) { chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, client: r.client, concurrency: r.concurrency} - return newTGReader(r.ctx, start, end, r.config, chunkSrc) + if r.concurrency < 2 { + return newTGReader(r.ctx, start, end, chunkSrc) + } + return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc) } diff --git a/internal/reader/tg_multi_reader.go b/internal/reader/tg_multi_reader.go new file mode 100644 index 0000000..51f8f8c --- /dev/null +++ b/internal/reader/tg_multi_reader.go @@ -0,0 +1,281 @@ +package reader + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/divyam234/teldrive/internal/config" + "github.com/divyam234/teldrive/internal/tgc" + "github.com/gotd/td/tg" + "golang.org/x/sync/errgroup" +) + +var ErrorStreamAbandoned = errors.New("stream abandoned") + +type ChunkSource interface { + Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) + ChunkSize(start, end int64) int64 +} + +type chunkSource struct { + channelId int64 + worker *tgc.StreamWorker + fileId string + partId int64 + concurrency int + client *tgc.Client +} + +func (c *chunkSource) ChunkSize(start, end int64) int64 { + return tgc.CalculateChunkSize(start, end) +} + +func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) { + var ( + location *tg.InputDocumentFileLocation + err error + client *tgc.Client + ) + + client = c.client + + defer func() { + if c.concurrency > 0 && client != nil { + defer c.worker.Release(client) + } + }() + + if c.concurrency > 0 { + client, _, _ = c.worker.Next(c.channelId) + } + location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId) + + if err != nil { + return nil, err + } + + return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit) + +} + +type tgMultiReader struct { + ctx context.Context + offset int64 + limit int64 + chunkSize int64 + bufferChan chan *buffer + done chan struct{} + cur *buffer + err chan error + mu sync.Mutex + concurrency int + leftCut int64 + rightCut int64 + totalParts int + currentPart int + closed bool + timeout time.Duration + chunkSrc ChunkSource +} + +func newTGMultiReader( + ctx context.Context, + start int64, + end int64, + config *config.TGConfig, + chunkSrc ChunkSource, + +) (*tgMultiReader, error) { + + chunkSize := chunkSrc.ChunkSize(start, end) + + offset := start - (start % chunkSize) + + r := &tgMultiReader{ + ctx: ctx, + limit: end - start + 1, + bufferChan: make(chan *buffer, config.Stream.Buffers), + concurrency: config.Stream.MultiThreads, + leftCut: start - offset, + rightCut: (end % chunkSize) + 1, + totalParts: int((end - offset + chunkSize) / chunkSize), + offset: offset, + chunkSize: chunkSize, + chunkSrc: chunkSrc, + timeout: config.Stream.ChunkTimeout, + done: make(chan struct{}, 1), + err: make(chan error, 1), + } + + go r.fillBufferConcurrently() + return r, nil +} + +func (r *tgMultiReader) Close() error { + close(r.done) + close(r.bufferChan) + r.closed = true + for b := range r.bufferChan { + if b != nil { + b = nil + } + } + if r.cur != nil { + r.cur = nil + } + close(r.err) + return nil +} + +func (r *tgMultiReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.limit <= 0 { + return 0, io.EOF + } + + if r.cur.isEmpty() { + if r.cur != nil { + r.cur = nil + } + select { + case cur, ok := <-r.bufferChan: + if !ok && r.limit > 0 { + return 0, ErrorStreamAbandoned + } + r.cur = cur + + case err := <-r.err: + return 0, fmt.Errorf("error reading chunk: %w", err) + case <-r.ctx.Done(): + return 0, r.ctx.Err() + + } + } + + n := copy(p, r.cur.buffer()) + r.cur.increment(n) + r.limit -= int64(n) + + if r.limit <= 0 { + return n, io.EOF + } + + return n, nil +} + +func (r *tgMultiReader) fillBufferConcurrently() error { + + var mapMu sync.Mutex + + bufferMap := make(map[int]*buffer) + + defer func() { + + for i := range bufferMap { + delete(bufferMap, i) + } + }() + + cb := func(ctx context.Context, i int) func() error { + return func() error { + + chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize) + if err != nil { + return err + } + if r.totalParts == 1 { + chunk = chunk[r.leftCut:r.rightCut] + } else if r.currentPart+i+1 == 1 { + chunk = chunk[r.leftCut:] + } else if r.currentPart+i+1 == r.totalParts { + chunk = chunk[:r.rightCut] + } + buf := &buffer{buf: chunk} + mapMu.Lock() + bufferMap[i] = buf + mapMu.Unlock() + return nil + } + } + + for { + + g := errgroup.Group{} + + g.SetLimit(r.concurrency) + + for i := range r.concurrency { + if r.currentPart+i+1 <= r.totalParts { + g.Go(cb(r.ctx, i)) + } + } + + done := make(chan error, 1) + + go func() { + done <- g.Wait() + }() + + select { + case err := <-done: + if err != nil { + r.err <- err + return nil + } else { + for i := range r.concurrency { + if r.currentPart+i+1 <= r.totalParts { + select { + case <-r.done: + return nil + case r.bufferChan <- bufferMap[i]: + } + } + } + r.currentPart += r.concurrency + r.offset += r.chunkSize * int64(r.concurrency) + for i := range bufferMap { + delete(bufferMap, i) + } + if r.currentPart >= r.totalParts { + return nil + } + } + case <-time.After(r.timeout): + return nil + case <-r.done: + return nil + case <-r.ctx.Done(): + return r.ctx.Err() + } + + } +} + +type buffer struct { + buf []byte + offset int +} + +func (b *buffer) isEmpty() bool { + if b == nil { + return true + } + if len(b.buf)-b.offset <= 0 { + return true + } + return false +} + +func (b *buffer) buffer() []byte { + return b.buf[b.offset:] +} + +func (b *buffer) increment(n int) { + b.offset += n +} diff --git a/internal/reader/tg_reader_test.go b/internal/reader/tg_multi_reader_test.go similarity index 88% rename from internal/reader/tg_reader_test.go rename to internal/reader/tg_multi_reader_test.go index 0dcc93b..2ff7305 100644 --- a/internal/reader/tg_reader_test.go +++ b/internal/reader/tg_multi_reader_test.go @@ -59,7 +59,7 @@ func (suite *TestSuite) TestFullRead() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSource{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) test_data, err := io.ReadAll(reader) assert.Equal(suite.T(), nil, err) @@ -73,7 +73,7 @@ func (suite *TestSuite) TestPartialRead() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSource{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) test_data, err := io.ReadAll(reader) assert.NoError(suite.T(), err) @@ -87,7 +87,7 @@ func (suite *TestSuite) TestTimeout() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSourceTimeout{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) test_data, err := io.ReadAll(reader) assert.Greater(suite.T(), len(test_data), 0) @@ -101,7 +101,7 @@ func (suite *TestSuite) TestClose() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSource{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) _, err = io.ReadAll(reader) assert.NoError(suite.T(), err) @@ -115,7 +115,7 @@ func (suite *TestSuite) TestCancellation() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSource{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) cancel() _, err = io.ReadAll(reader) @@ -131,7 +131,7 @@ func (suite *TestSuite) TestCancellationWithTimeout() { data := make([]byte, 100) rand.Read(data) chunkSrc := &testChunkSourceTimeout{buffer: data} - reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc) + reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc) assert.NoError(suite.T(), err) _, err = io.ReadAll(reader) assert.Equal(suite.T(), err, context.DeadlineExceeded) diff --git a/internal/reader/tg_reader.go b/internal/reader/tg_reader.go index b19814e..698ee7e 100644 --- a/internal/reader/tg_reader.go +++ b/internal/reader/tg_reader.go @@ -2,94 +2,30 @@ package reader import ( "context" - "errors" - "fmt" "io" - "sync" - "time" - - "github.com/divyam234/teldrive/internal/config" - "github.com/divyam234/teldrive/internal/tgc" - "github.com/gotd/td/tg" - "golang.org/x/sync/errgroup" ) -var ErrorStreamAbandoned = errors.New("stream abandoned") - -type ChunkSource interface { - Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) - ChunkSize(start, end int64) int64 -} - -type chunkSource struct { - channelId int64 - worker *tgc.StreamWorker - fileId string - partId int64 - concurrency int - client *tgc.Client -} - -func (c *chunkSource) ChunkSize(start, end int64) int64 { - return tgc.CalculateChunkSize(start, end) -} - -func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) { - var ( - location *tg.InputDocumentFileLocation - err error - client *tgc.Client - ) - - client = c.client - - defer func() { - if c.concurrency > 0 && client != nil { - defer c.worker.Release(client) - } - }() - - if c.concurrency > 0 { - client, _, _ = c.worker.Next(c.channelId) - } - location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId) - - if err != nil { - return nil, err - } - - return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit) - -} - type tgReader struct { ctx context.Context + cur *buffer offset int64 limit int64 chunkSize int64 - bufferChan chan *buffer - done chan struct{} - cur *buffer - err chan error - mu sync.Mutex - concurrency int leftCut int64 rightCut int64 totalParts int currentPart int - closed bool - timeout time.Duration chunkSrc ChunkSource + err error } func newTGReader( ctx context.Context, start int64, end int64, - config *config.TGConfig, chunkSrc ChunkSource, -) (*tgReader, error) { +) (io.ReadCloser, error) { chunkSize := chunkSrc.ChunkSize(start, end) @@ -97,70 +33,28 @@ func newTGReader( r := &tgReader{ ctx: ctx, - limit: end - start + 1, - bufferChan: make(chan *buffer, config.Stream.Buffers), - concurrency: config.Stream.MultiThreads, leftCut: start - offset, rightCut: (end % chunkSize) + 1, totalParts: int((end - offset + chunkSize) / chunkSize), offset: offset, + limit: end - start + 1, chunkSize: chunkSize, chunkSrc: chunkSrc, - timeout: config.Stream.ChunkTimeout, - done: make(chan struct{}, 1), - err: make(chan error, 1), + currentPart: 1, } - - if r.concurrency == 0 { - r.currentPart = 1 - go r.fillBufferSequentially() - } else { - go r.fillBufferConcurrently() - } - return r, nil } -func (r *tgReader) Close() error { - close(r.done) - close(r.bufferChan) - r.closed = true - for b := range r.bufferChan { - if b != nil { - b = nil - } - } - if r.cur != nil { - r.cur = nil - } - close(r.err) - return nil -} - func (r *tgReader) Read(p []byte) (int, error) { - r.mu.Lock() - defer r.mu.Unlock() if r.limit <= 0 { return 0, io.EOF } if r.cur.isEmpty() { - if r.cur != nil { - r.cur = nil - } - select { - case cur, ok := <-r.bufferChan: - if !ok && r.limit > 0 { - return 0, ErrorStreamAbandoned - } - r.cur = cur - - case err := <-r.err: - return 0, fmt.Errorf("error reading chunk: %w", err) - case <-r.ctx.Done(): - return 0, r.ctx.Err() - + r.cur, r.err = r.next() + if r.err != nil { + return 0, r.err } } @@ -175,157 +69,29 @@ func (r *tgReader) Read(p []byte) (int, error) { return n, nil } -func (r *tgReader) fillBufferConcurrently() error { - - var mapMu sync.Mutex - - bufferMap := make(map[int]*buffer) - - defer func() { - - for i := range bufferMap { - delete(bufferMap, i) - } - }() - - cb := func(ctx context.Context, i int) func() error { - return func() error { - - chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize) - if err != nil { - return err - } - if r.totalParts == 1 { - chunk = chunk[r.leftCut:r.rightCut] - } else if r.currentPart+i+1 == 1 { - chunk = chunk[r.leftCut:] - } else if r.currentPart+i+1 == r.totalParts { - chunk = chunk[:r.rightCut] - } - buf := &buffer{buf: chunk} - mapMu.Lock() - bufferMap[i] = buf - mapMu.Unlock() - return nil - } - } - - for { - - g := errgroup.Group{} - - g.SetLimit(r.concurrency) - - for i := range r.concurrency { - if r.currentPart+i+1 <= r.totalParts { - g.Go(cb(r.ctx, i)) - } - } - - done := make(chan error, 1) - - go func() { - done <- g.Wait() - }() - - select { - case err := <-done: - if err != nil { - r.err <- err - return nil - } else { - for i := range r.concurrency { - if r.currentPart+i+1 <= r.totalParts { - select { - case <-r.done: - return nil - case r.bufferChan <- bufferMap[i]: - } - } - } - r.currentPart += r.concurrency - r.offset += r.chunkSize * int64(r.concurrency) - for i := range bufferMap { - delete(bufferMap, i) - } - if r.currentPart >= r.totalParts { - return nil - } - } - case <-time.After(r.timeout): - return nil - case <-r.done: - return nil - case <-r.ctx.Done(): - return r.ctx.Err() - } - - } +func (*tgReader) Close() error { + return nil } -func (r *tgReader) fillBufferSequentially() error { +func (r *tgReader) next() (*buffer, error) { - fetchChunk := func(ctx context.Context) (*buffer, error) { - chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize) - if err != nil { - return nil, err - } - if r.totalParts == 1 { - chunk = chunk[r.leftCut:r.rightCut] - } else if r.currentPart == 1 { - chunk = chunk[r.leftCut:] - } else if r.currentPart == r.totalParts { - chunk = chunk[:r.rightCut] - } - return &buffer{buf: chunk}, nil + if r.currentPart > r.totalParts { + return nil, io.EOF + } + chunk, err := r.chunkSrc.Chunk(r.ctx, r.offset, r.chunkSize) + if err != nil { + return nil, err + } + if r.totalParts == 1 { + chunk = chunk[r.leftCut:r.rightCut] + } else if r.currentPart == 1 { + chunk = chunk[r.leftCut:] + } else if r.currentPart == r.totalParts { + chunk = chunk[:r.rightCut] } - for { - select { - case <-r.done: - return nil - case <-r.ctx.Done(): - return r.ctx.Err() - case <-time.After(r.timeout): - return nil - default: - buf, err := fetchChunk(r.ctx) - if err != nil { - r.err <- err - return nil - } - if r.closed { - return nil - } - r.bufferChan <- buf - r.currentPart++ - r.offset += r.chunkSize - if r.currentPart > r.totalParts { - return nil - } - } - } -} + r.currentPart++ + r.offset += r.chunkSize + return &buffer{buf: chunk}, nil -type buffer struct { - buf []byte - offset int -} - -func (b *buffer) isEmpty() bool { - if b == nil { - return true - } - if len(b.buf)-b.offset <= 0 { - return true - } - return false -} - -func (b *buffer) buffer() []byte { - return b.buf[b.offset:] -} - -func (b *buffer) increment(n int) { - b.offset += n }