refactor: reader

divyam234 2024-06-29 15:07:54 +05:30
parent 04755c76e6
commit c5cd24bbb3
5 changed files with 323 additions and 269 deletions


@@ -112,7 +112,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
if r.concurrency < 2 {
return newTGReader(r.ctx, start, end, chunkSrc)
}
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
}, start, end-start+1)


@@ -117,7 +117,10 @@ func (r *linearReader) nextPart() (io.ReadCloser, error) {
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
client: r.client, concurrency: r.concurrency}
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
if r.concurrency < 2 {
return newTGReader(r.ctx, start, end, chunkSrc)
}
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
}
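Both hunks above make the same selection: the plain tgReader serves the single-connection case, and the new tgMultiReader is constructed only when more than one download thread is configured. A minimal sketch of that dispatch, written as a hypothetical helper in the same reader package (the commit itself inlines the check at each call site, and concurrency presumably comes from config.Stream.MultiThreads):

func pickReader(ctx context.Context, start, end int64, cfg *config.TGConfig, src ChunkSource, concurrency int) (io.ReadCloser, error) {
	if concurrency < 2 {
		// Sequential path: chunks are fetched one at a time, no stream config needed.
		return newTGReader(ctx, start, end, src)
	}
	// Concurrent path: chunks are prefetched by multiple workers into a buffered channel.
	return newTGMultiReader(ctx, start, end, cfg, src)
}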


@@ -0,0 +1,281 @@
package reader
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/gotd/td/tg"
"golang.org/x/sync/errgroup"
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
type ChunkSource interface {
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
ChunkSize(start, end int64) int64
}
type chunkSource struct {
channelId int64
worker *tgc.StreamWorker
fileId string
partId int64
concurrency int
client *tgc.Client
}
func (c *chunkSource) ChunkSize(start, end int64) int64 {
return tgc.CalculateChunkSize(start, end)
}
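// Chunk downloads one block of the stored part: at most limit bytes starting
// at offset. When concurrency is enabled it borrows a client from the stream
// worker pool for this channel and releases it on return; otherwise it reuses
// the reader's own client. The part is resolved to a Telegram document
// location with tgc.GetLocation and fetched with tgc.GetChunk.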
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
var (
location *tg.InputDocumentFileLocation
err error
client *tgc.Client
)
client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId)
}
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
if err != nil {
return nil, err
}
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}
type tgMultiReader struct {
ctx context.Context
offset int64
limit int64
chunkSize int64
bufferChan chan *buffer
done chan struct{}
cur *buffer
err chan error
mu sync.Mutex
concurrency int
leftCut int64
rightCut int64
totalParts int
currentPart int
closed bool
timeout time.Duration
chunkSrc ChunkSource
}
func newTGMultiReader(
ctx context.Context,
start int64,
end int64,
config *config.TGConfig,
chunkSrc ChunkSource,
) (*tgMultiReader, error) {
chunkSize := chunkSrc.ChunkSize(start, end)
offset := start - (start % chunkSize)
r := &tgMultiReader{
ctx: ctx,
limit: end - start + 1,
bufferChan: make(chan *buffer, config.Stream.Buffers),
concurrency: config.Stream.MultiThreads,
leftCut: start - offset,
rightCut: (end % chunkSize) + 1,
totalParts: int((end - offset + chunkSize) / chunkSize),
offset: offset,
chunkSize: chunkSize,
chunkSrc: chunkSrc,
timeout: config.Stream.ChunkTimeout,
done: make(chan struct{}, 1),
err: make(chan error, 1),
}
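// Illustrative example of the bookkeeping above (assuming ChunkSize returned
// 64 for this range): with start=10 and end=109 the reader fetches aligned
// chunks from offset 0, leftCut=10 trims the head of the first chunk,
// rightCut=(109%64)+1=46 trims the tail of the last one, and
// totalParts=(109-0+64)/64=2; the trimmed chunks yield 54+46 bytes,
// exactly limit = end-start+1 = 100.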
go r.fillBufferConcurrently()
return r, nil
}
func (r *tgMultiReader) Close() error {
close(r.done)
close(r.bufferChan)
r.closed = true
for b := range r.bufferChan {
if b != nil {
b = nil
}
}
if r.cur != nil {
r.cur = nil
}
close(r.err)
return nil
}
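// Read copies out of the buffer currently being consumed and, once it is
// exhausted, blocks for the next one from bufferChan. It returns io.EOF after
// limit bytes have been delivered, the producer's error if one was reported
// on the err channel, ErrorStreamAbandoned if bufferChan is closed before the
// limit is reached, and the context's error on cancellation.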
func (r *tgMultiReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.limit <= 0 {
return 0, io.EOF
}
if r.cur.isEmpty() {
if r.cur != nil {
r.cur = nil
}
select {
case cur, ok := <-r.bufferChan:
if !ok && r.limit > 0 {
return 0, ErrorStreamAbandoned
}
r.cur = cur
case err := <-r.err:
return 0, fmt.Errorf("error reading chunk: %w", err)
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
}
n := copy(p, r.cur.buffer())
r.cur.increment(n)
r.limit -= int64(n)
if r.limit <= 0 {
return n, io.EOF
}
return n, nil
}
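// fillBufferConcurrently is the producer: each round it downloads up to
// concurrency chunks in parallel (an errgroup with SetLimit), trims
// leftCut/rightCut on the first and last parts, and hands the buffers to Read
// over bufferChan in part order. It stops after the last part, when a fetch
// fails (the error is forwarded on r.err), when a round exceeds the chunk
// timeout, on Close, or when the context is cancelled.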
func (r *tgMultiReader) fillBufferConcurrently() error {
var mapMu sync.Mutex
bufferMap := make(map[int]*buffer)
defer func() {
for i := range bufferMap {
delete(bufferMap, i)
}
}()
cb := func(ctx context.Context, i int) func() error {
return func() error {
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
if err != nil {
return err
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart+i+1 == 1 {
chunk = chunk[r.leftCut:]
} else if r.currentPart+i+1 == r.totalParts {
chunk = chunk[:r.rightCut]
}
buf := &buffer{buf: chunk}
mapMu.Lock()
bufferMap[i] = buf
mapMu.Unlock()
return nil
}
}
for {
g := errgroup.Group{}
g.SetLimit(r.concurrency)
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
g.Go(cb(r.ctx, i))
}
}
done := make(chan error, 1)
go func() {
done <- g.Wait()
}()
select {
case err := <-done:
if err != nil {
r.err <- err
return nil
} else {
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
select {
case <-r.done:
return nil
case r.bufferChan <- bufferMap[i]:
}
}
}
r.currentPart += r.concurrency
r.offset += r.chunkSize * int64(r.concurrency)
for i := range bufferMap {
delete(bufferMap, i)
}
if r.currentPart >= r.totalParts {
return nil
}
}
case <-time.After(r.timeout):
return nil
case <-r.done:
return nil
case <-r.ctx.Done():
return r.ctx.Err()
}
}
}
type buffer struct {
buf []byte
offset int
}
func (b *buffer) isEmpty() bool {
if b == nil {
return true
}
if len(b.buf)-b.offset <= 0 {
return true
}
return false
}
func (b *buffer) buffer() []byte {
return b.buf[b.offset:]
}
func (b *buffer) increment(n int) {
b.offset += n
}
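Nothing above is Telegram-specific apart from chunkSource itself; any ChunkSource drives the reader. A minimal in-memory source along these lines (a sketch in the spirit of the testChunkSource used by the tests below, not code from the commit, assumed to live in the same reader package) is enough to exercise newTGMultiReader:

// inMemorySource serves chunks straight from a byte slice (illustrative only).
type inMemorySource struct {
	data []byte
}

func (s *inMemorySource) ChunkSize(start, end int64) int64 {
	return 16 // fixed size for the sketch; the real source asks tgc.CalculateChunkSize
}

func (s *inMemorySource) Chunk(ctx context.Context, offset, limit int64) ([]byte, error) {
	end := min(offset+limit, int64(len(s.data)))
	if offset >= end {
		return nil, nil
	}
	return s.data[offset:end], nil
}

// Usage, assuming cfg is a *config.TGConfig with Stream.Buffers,
// Stream.MultiThreads and Stream.ChunkTimeout populated and payload holds at
// least 110 bytes:
//
//	r, err := newTGMultiReader(ctx, 10, 109, cfg, &inMemorySource{data: payload})
//	got, _ := io.ReadAll(r) // got == payload[10:110]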


@@ -59,7 +59,7 @@ func (suite *TestSuite) TestFullRead() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSource{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
test_data, err := io.ReadAll(reader)
assert.Equal(suite.T(), nil, err)
@@ -73,7 +73,7 @@ func (suite *TestSuite) TestPartialRead() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSource{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
test_data, err := io.ReadAll(reader)
assert.NoError(suite.T(), err)
@@ -87,7 +87,7 @@ func (suite *TestSuite) TestTimeout() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSourceTimeout{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
test_data, err := io.ReadAll(reader)
assert.Greater(suite.T(), len(test_data), 0)
@@ -101,7 +101,7 @@ func (suite *TestSuite) TestClose() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSource{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
_, err = io.ReadAll(reader)
assert.NoError(suite.T(), err)
@@ -115,7 +115,7 @@ func (suite *TestSuite) TestCancellation() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSource{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
cancel()
_, err = io.ReadAll(reader)
@@ -131,7 +131,7 @@ func (suite *TestSuite) TestCancellationWithTimeout() {
data := make([]byte, 100)
rand.Read(data)
chunkSrc := &testChunkSourceTimeout{buffer: data}
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
assert.NoError(suite.T(), err)
_, err = io.ReadAll(reader)
assert.Equal(suite.T(), err, context.DeadlineExceeded)
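The timeout and cancellation tests need a source that does not answer in time. One simple way to hit the producer's time.After branch is a source that delays each fetch past config.Stream.ChunkTimeout; this is a hypothetical sketch and not necessarily how the suite's testChunkSourceTimeout behaves:

// slowChunkSource delays every fetch; if delay exceeds the configured chunk
// timeout, the multi reader's producer round never completes in time and the
// producer gives up. Illustrative only.
type slowChunkSource struct {
	data  []byte
	delay time.Duration
}

func (s *slowChunkSource) ChunkSize(start, end int64) int64 { return 16 }

func (s *slowChunkSource) Chunk(ctx context.Context, offset, limit int64) ([]byte, error) {
	select {
	case <-time.After(s.delay):
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	end := min(offset+limit, int64(len(s.data)))
	if offset >= end {
		return nil, nil
	}
	return s.data[offset:end], nil
}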


@@ -2,94 +2,30 @@ package reader
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/gotd/td/tg"
"golang.org/x/sync/errgroup"
)
var ErrorStreamAbandoned = errors.New("stream abandoned")
type ChunkSource interface {
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
ChunkSize(start, end int64) int64
}
type chunkSource struct {
channelId int64
worker *tgc.StreamWorker
fileId string
partId int64
concurrency int
client *tgc.Client
}
func (c *chunkSource) ChunkSize(start, end int64) int64 {
return tgc.CalculateChunkSize(start, end)
}
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
var (
location *tg.InputDocumentFileLocation
err error
client *tgc.Client
)
client = c.client
defer func() {
if c.concurrency > 0 && client != nil {
defer c.worker.Release(client)
}
}()
if c.concurrency > 0 {
client, _, _ = c.worker.Next(c.channelId)
}
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
if err != nil {
return nil, err
}
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
}
type tgReader struct {
ctx context.Context
cur *buffer
offset int64
limit int64
chunkSize int64
bufferChan chan *buffer
done chan struct{}
cur *buffer
err chan error
mu sync.Mutex
concurrency int
leftCut int64
rightCut int64
totalParts int
currentPart int
closed bool
timeout time.Duration
chunkSrc ChunkSource
err error
}
func newTGReader(
ctx context.Context,
start int64,
end int64,
config *config.TGConfig,
chunkSrc ChunkSource,
) (*tgReader, error) {
) (io.ReadCloser, error) {
chunkSize := chunkSrc.ChunkSize(start, end)
@@ -97,70 +33,28 @@ func newTGReader(
r := &tgReader{
ctx: ctx,
limit: end - start + 1,
bufferChan: make(chan *buffer, config.Stream.Buffers),
concurrency: config.Stream.MultiThreads,
leftCut: start - offset,
rightCut: (end % chunkSize) + 1,
totalParts: int((end - offset + chunkSize) / chunkSize),
offset: offset,
limit: end - start + 1,
chunkSize: chunkSize,
chunkSrc: chunkSrc,
timeout: config.Stream.ChunkTimeout,
done: make(chan struct{}, 1),
err: make(chan error, 1),
currentPart: 1,
}
if r.concurrency == 0 {
r.currentPart = 1
go r.fillBufferSequentially()
} else {
go r.fillBufferConcurrently()
}
return r, nil
}
func (r *tgReader) Close() error {
close(r.done)
close(r.bufferChan)
r.closed = true
for b := range r.bufferChan {
if b != nil {
b = nil
}
}
if r.cur != nil {
r.cur = nil
}
close(r.err)
return nil
}
func (r *tgReader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.limit <= 0 {
return 0, io.EOF
}
if r.cur.isEmpty() {
if r.cur != nil {
r.cur = nil
}
select {
case cur, ok := <-r.bufferChan:
if !ok && r.limit > 0 {
return 0, ErrorStreamAbandoned
}
r.cur = cur
case err := <-r.err:
return 0, fmt.Errorf("error reading chunk: %w", err)
case <-r.ctx.Done():
return 0, r.ctx.Err()
r.cur, r.err = r.next()
if r.err != nil {
return 0, r.err
}
}
@@ -175,157 +69,29 @@ func (r *tgReader) Read(p []byte) (int, error) {
return n, nil
}
func (r *tgReader) fillBufferConcurrently() error {
var mapMu sync.Mutex
bufferMap := make(map[int]*buffer)
defer func() {
for i := range bufferMap {
delete(bufferMap, i)
}
}()
cb := func(ctx context.Context, i int) func() error {
return func() error {
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
if err != nil {
return err
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart+i+1 == 1 {
chunk = chunk[r.leftCut:]
} else if r.currentPart+i+1 == r.totalParts {
chunk = chunk[:r.rightCut]
}
buf := &buffer{buf: chunk}
mapMu.Lock()
bufferMap[i] = buf
mapMu.Unlock()
return nil
}
}
for {
g := errgroup.Group{}
g.SetLimit(r.concurrency)
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
g.Go(cb(r.ctx, i))
}
}
done := make(chan error, 1)
go func() {
done <- g.Wait()
}()
select {
case err := <-done:
if err != nil {
r.err <- err
return nil
} else {
for i := range r.concurrency {
if r.currentPart+i+1 <= r.totalParts {
select {
case <-r.done:
return nil
case r.bufferChan <- bufferMap[i]:
}
}
}
r.currentPart += r.concurrency
r.offset += r.chunkSize * int64(r.concurrency)
for i := range bufferMap {
delete(bufferMap, i)
}
if r.currentPart >= r.totalParts {
return nil
}
}
case <-time.After(r.timeout):
return nil
case <-r.done:
return nil
case <-r.ctx.Done():
return r.ctx.Err()
}
}
func (*tgReader) Close() error {
return nil
}
func (r *tgReader) fillBufferSequentially() error {
func (r *tgReader) next() (*buffer, error) {
fetchChunk := func(ctx context.Context) (*buffer, error) {
chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
if err != nil {
return nil, err
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart == 1 {
chunk = chunk[r.leftCut:]
} else if r.currentPart == r.totalParts {
chunk = chunk[:r.rightCut]
}
return &buffer{buf: chunk}, nil
if r.currentPart > r.totalParts {
return nil, io.EOF
}
chunk, err := r.chunkSrc.Chunk(r.ctx, r.offset, r.chunkSize)
if err != nil {
return nil, err
}
if r.totalParts == 1 {
chunk = chunk[r.leftCut:r.rightCut]
} else if r.currentPart == 1 {
chunk = chunk[r.leftCut:]
} else if r.currentPart == r.totalParts {
chunk = chunk[:r.rightCut]
}
for {
select {
case <-r.done:
return nil
case <-r.ctx.Done():
return r.ctx.Err()
case <-time.After(r.timeout):
return nil
default:
buf, err := fetchChunk(r.ctx)
if err != nil {
r.err <- err
return nil
}
if r.closed {
return nil
}
r.bufferChan <- buf
r.currentPart++
r.offset += r.chunkSize
if r.currentPart > r.totalParts {
return nil
}
}
}
}
r.currentPart++
r.offset += r.chunkSize
return &buffer{buf: chunk}, nil
type buffer struct {
buf []byte
offset int
}
func (b *buffer) isEmpty() bool {
if b == nil {
return true
}
if len(b.buf)-b.offset <= 0 {
return true
}
return false
}
func (b *buffer) buffer() []byte {
return b.buf[b.offset:]
}
func (b *buffer) increment(n int) {
b.offset += n
}
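After this refactor the single-threaded tgReader pulls one chunk per next() call from inside Read, with no goroutines, channels, or stream config. A usage sketch, reusing the hypothetical inMemorySource from earlier (illustrative, same reader package assumed):

func readRangeSequentially(ctx context.Context, payload []byte) ([]byte, error) {
	// concurrency < 2 path from nextPart: newTGReader only needs the range
	// and a ChunkSource and hands back an io.ReadCloser.
	r, err := newTGReader(ctx, 10, 109, &inMemorySource{data: payload})
	if err != nil {
		return nil, err
	}
	defer r.Close()
	return io.ReadAll(r) // bytes 10 through 109 of payload
}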