mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-13 18:04:29 +08:00
refactor: reader
This commit is contained in:
parent
04755c76e6
commit
c5cd24bbb3
5 changed files with 323 additions and 269 deletions
|
@ -112,7 +112,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||||
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||||
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||||
client: r.client, concurrency: r.concurrency}
|
client: r.client, concurrency: r.concurrency}
|
||||||
return newTGReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
|
||||||
|
if r.concurrency < 2 {
|
||||||
|
return newTGReader(r.ctx, start, end, chunkSrc)
|
||||||
|
}
|
||||||
|
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
||||||
|
|
||||||
}, start, end-start+1)
|
}, start, end-start+1)
|
||||||
|
|
||||||
|
|
|
@ -117,7 +117,10 @@ func (r *linearReader) nextPart() (io.ReadCloser, error) {
|
||||||
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker,
|
||||||
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID,
|
||||||
client: r.client, concurrency: r.concurrency}
|
client: r.client, concurrency: r.concurrency}
|
||||||
return newTGReader(r.ctx, start, end, r.config, chunkSrc)
|
if r.concurrency < 2 {
|
||||||
|
return newTGReader(r.ctx, start, end, chunkSrc)
|
||||||
|
}
|
||||||
|
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
281
internal/reader/tg_multi_reader.go
Normal file
281
internal/reader/tg_multi_reader.go
Normal file
|
@ -0,0 +1,281 @@
|
||||||
|
package reader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/divyam234/teldrive/internal/config"
|
||||||
|
"github.com/divyam234/teldrive/internal/tgc"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
||||||
|
|
||||||
|
type ChunkSource interface {
|
||||||
|
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
|
||||||
|
ChunkSize(start, end int64) int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkSource struct {
|
||||||
|
channelId int64
|
||||||
|
worker *tgc.StreamWorker
|
||||||
|
fileId string
|
||||||
|
partId int64
|
||||||
|
concurrency int
|
||||||
|
client *tgc.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chunkSource) ChunkSize(start, end int64) int64 {
|
||||||
|
return tgc.CalculateChunkSize(start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
||||||
|
var (
|
||||||
|
location *tg.InputDocumentFileLocation
|
||||||
|
err error
|
||||||
|
client *tgc.Client
|
||||||
|
)
|
||||||
|
|
||||||
|
client = c.client
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if c.concurrency > 0 && client != nil {
|
||||||
|
defer c.worker.Release(client)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if c.concurrency > 0 {
|
||||||
|
client, _, _ = c.worker.Next(c.channelId)
|
||||||
|
}
|
||||||
|
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type tgMultiReader struct {
|
||||||
|
ctx context.Context
|
||||||
|
offset int64
|
||||||
|
limit int64
|
||||||
|
chunkSize int64
|
||||||
|
bufferChan chan *buffer
|
||||||
|
done chan struct{}
|
||||||
|
cur *buffer
|
||||||
|
err chan error
|
||||||
|
mu sync.Mutex
|
||||||
|
concurrency int
|
||||||
|
leftCut int64
|
||||||
|
rightCut int64
|
||||||
|
totalParts int
|
||||||
|
currentPart int
|
||||||
|
closed bool
|
||||||
|
timeout time.Duration
|
||||||
|
chunkSrc ChunkSource
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTGMultiReader(
|
||||||
|
ctx context.Context,
|
||||||
|
start int64,
|
||||||
|
end int64,
|
||||||
|
config *config.TGConfig,
|
||||||
|
chunkSrc ChunkSource,
|
||||||
|
|
||||||
|
) (*tgMultiReader, error) {
|
||||||
|
|
||||||
|
chunkSize := chunkSrc.ChunkSize(start, end)
|
||||||
|
|
||||||
|
offset := start - (start % chunkSize)
|
||||||
|
|
||||||
|
r := &tgMultiReader{
|
||||||
|
ctx: ctx,
|
||||||
|
limit: end - start + 1,
|
||||||
|
bufferChan: make(chan *buffer, config.Stream.Buffers),
|
||||||
|
concurrency: config.Stream.MultiThreads,
|
||||||
|
leftCut: start - offset,
|
||||||
|
rightCut: (end % chunkSize) + 1,
|
||||||
|
totalParts: int((end - offset + chunkSize) / chunkSize),
|
||||||
|
offset: offset,
|
||||||
|
chunkSize: chunkSize,
|
||||||
|
chunkSrc: chunkSrc,
|
||||||
|
timeout: config.Stream.ChunkTimeout,
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
err: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
go r.fillBufferConcurrently()
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgMultiReader) Close() error {
|
||||||
|
close(r.done)
|
||||||
|
close(r.bufferChan)
|
||||||
|
r.closed = true
|
||||||
|
for b := range r.bufferChan {
|
||||||
|
if b != nil {
|
||||||
|
b = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.cur != nil {
|
||||||
|
r.cur = nil
|
||||||
|
}
|
||||||
|
close(r.err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgMultiReader) Read(p []byte) (int, error) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.limit <= 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.cur.isEmpty() {
|
||||||
|
if r.cur != nil {
|
||||||
|
r.cur = nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case cur, ok := <-r.bufferChan:
|
||||||
|
if !ok && r.limit > 0 {
|
||||||
|
return 0, ErrorStreamAbandoned
|
||||||
|
}
|
||||||
|
r.cur = cur
|
||||||
|
|
||||||
|
case err := <-r.err:
|
||||||
|
return 0, fmt.Errorf("error reading chunk: %w", err)
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return 0, r.ctx.Err()
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
n := copy(p, r.cur.buffer())
|
||||||
|
r.cur.increment(n)
|
||||||
|
r.limit -= int64(n)
|
||||||
|
|
||||||
|
if r.limit <= 0 {
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tgMultiReader) fillBufferConcurrently() error {
|
||||||
|
|
||||||
|
var mapMu sync.Mutex
|
||||||
|
|
||||||
|
bufferMap := make(map[int]*buffer)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
|
||||||
|
for i := range bufferMap {
|
||||||
|
delete(bufferMap, i)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
cb := func(ctx context.Context, i int) func() error {
|
||||||
|
return func() error {
|
||||||
|
|
||||||
|
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r.totalParts == 1 {
|
||||||
|
chunk = chunk[r.leftCut:r.rightCut]
|
||||||
|
} else if r.currentPart+i+1 == 1 {
|
||||||
|
chunk = chunk[r.leftCut:]
|
||||||
|
} else if r.currentPart+i+1 == r.totalParts {
|
||||||
|
chunk = chunk[:r.rightCut]
|
||||||
|
}
|
||||||
|
buf := &buffer{buf: chunk}
|
||||||
|
mapMu.Lock()
|
||||||
|
bufferMap[i] = buf
|
||||||
|
mapMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
g := errgroup.Group{}
|
||||||
|
|
||||||
|
g.SetLimit(r.concurrency)
|
||||||
|
|
||||||
|
for i := range r.concurrency {
|
||||||
|
if r.currentPart+i+1 <= r.totalParts {
|
||||||
|
g.Go(cb(r.ctx, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
done <- g.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
r.err <- err
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
for i := range r.concurrency {
|
||||||
|
if r.currentPart+i+1 <= r.totalParts {
|
||||||
|
select {
|
||||||
|
case <-r.done:
|
||||||
|
return nil
|
||||||
|
case r.bufferChan <- bufferMap[i]:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.currentPart += r.concurrency
|
||||||
|
r.offset += r.chunkSize * int64(r.concurrency)
|
||||||
|
for i := range bufferMap {
|
||||||
|
delete(bufferMap, i)
|
||||||
|
}
|
||||||
|
if r.currentPart >= r.totalParts {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-time.After(r.timeout):
|
||||||
|
return nil
|
||||||
|
case <-r.done:
|
||||||
|
return nil
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return r.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type buffer struct {
|
||||||
|
buf []byte
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) isEmpty() bool {
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(b.buf)-b.offset <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) buffer() []byte {
|
||||||
|
return b.buf[b.offset:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *buffer) increment(n int) {
|
||||||
|
b.offset += n
|
||||||
|
}
|
|
@ -59,7 +59,7 @@ func (suite *TestSuite) TestFullRead() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSource{buffer: data}
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
test_data, err := io.ReadAll(reader)
|
test_data, err := io.ReadAll(reader)
|
||||||
assert.Equal(suite.T(), nil, err)
|
assert.Equal(suite.T(), nil, err)
|
||||||
|
@ -73,7 +73,7 @@ func (suite *TestSuite) TestPartialRead() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSource{buffer: data}
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
test_data, err := io.ReadAll(reader)
|
test_data, err := io.ReadAll(reader)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
@ -87,7 +87,7 @@ func (suite *TestSuite) TestTimeout() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
test_data, err := io.ReadAll(reader)
|
test_data, err := io.ReadAll(reader)
|
||||||
assert.Greater(suite.T(), len(test_data), 0)
|
assert.Greater(suite.T(), len(test_data), 0)
|
||||||
|
@ -101,7 +101,7 @@ func (suite *TestSuite) TestClose() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSource{buffer: data}
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
_, err = io.ReadAll(reader)
|
_, err = io.ReadAll(reader)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
@ -115,7 +115,7 @@ func (suite *TestSuite) TestCancellation() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSource{buffer: data}
|
chunkSrc := &testChunkSource{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
cancel()
|
cancel()
|
||||||
_, err = io.ReadAll(reader)
|
_, err = io.ReadAll(reader)
|
||||||
|
@ -131,7 +131,7 @@ func (suite *TestSuite) TestCancellationWithTimeout() {
|
||||||
data := make([]byte, 100)
|
data := make([]byte, 100)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
chunkSrc := &testChunkSourceTimeout{buffer: data}
|
||||||
reader, err := newTGReader(ctx, start, end, suite.config, chunkSrc)
|
reader, err := newTGMultiReader(ctx, start, end, suite.config, chunkSrc)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
_, err = io.ReadAll(reader)
|
_, err = io.ReadAll(reader)
|
||||||
assert.Equal(suite.T(), err, context.DeadlineExceeded)
|
assert.Equal(suite.T(), err, context.DeadlineExceeded)
|
|
@ -2,94 +2,30 @@ package reader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/divyam234/teldrive/internal/config"
|
|
||||||
"github.com/divyam234/teldrive/internal/tgc"
|
|
||||||
"github.com/gotd/td/tg"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrorStreamAbandoned = errors.New("stream abandoned")
|
|
||||||
|
|
||||||
type ChunkSource interface {
|
|
||||||
Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error)
|
|
||||||
ChunkSize(start, end int64) int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type chunkSource struct {
|
|
||||||
channelId int64
|
|
||||||
worker *tgc.StreamWorker
|
|
||||||
fileId string
|
|
||||||
partId int64
|
|
||||||
concurrency int
|
|
||||||
client *tgc.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chunkSource) ChunkSize(start, end int64) int64 {
|
|
||||||
return tgc.CalculateChunkSize(start, end)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]byte, error) {
|
|
||||||
var (
|
|
||||||
location *tg.InputDocumentFileLocation
|
|
||||||
err error
|
|
||||||
client *tgc.Client
|
|
||||||
)
|
|
||||||
|
|
||||||
client = c.client
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if c.concurrency > 0 && client != nil {
|
|
||||||
defer c.worker.Release(client)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if c.concurrency > 0 {
|
|
||||||
client, _, _ = c.worker.Next(c.channelId)
|
|
||||||
}
|
|
||||||
location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tgc.GetChunk(ctx, client.Tg.API(), location, offset, limit)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type tgReader struct {
|
type tgReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cur *buffer
|
||||||
offset int64
|
offset int64
|
||||||
limit int64
|
limit int64
|
||||||
chunkSize int64
|
chunkSize int64
|
||||||
bufferChan chan *buffer
|
|
||||||
done chan struct{}
|
|
||||||
cur *buffer
|
|
||||||
err chan error
|
|
||||||
mu sync.Mutex
|
|
||||||
concurrency int
|
|
||||||
leftCut int64
|
leftCut int64
|
||||||
rightCut int64
|
rightCut int64
|
||||||
totalParts int
|
totalParts int
|
||||||
currentPart int
|
currentPart int
|
||||||
closed bool
|
|
||||||
timeout time.Duration
|
|
||||||
chunkSrc ChunkSource
|
chunkSrc ChunkSource
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTGReader(
|
func newTGReader(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
start int64,
|
start int64,
|
||||||
end int64,
|
end int64,
|
||||||
config *config.TGConfig,
|
|
||||||
chunkSrc ChunkSource,
|
chunkSrc ChunkSource,
|
||||||
|
|
||||||
) (*tgReader, error) {
|
) (io.ReadCloser, error) {
|
||||||
|
|
||||||
chunkSize := chunkSrc.ChunkSize(start, end)
|
chunkSize := chunkSrc.ChunkSize(start, end)
|
||||||
|
|
||||||
|
@ -97,70 +33,28 @@ func newTGReader(
|
||||||
|
|
||||||
r := &tgReader{
|
r := &tgReader{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
limit: end - start + 1,
|
|
||||||
bufferChan: make(chan *buffer, config.Stream.Buffers),
|
|
||||||
concurrency: config.Stream.MultiThreads,
|
|
||||||
leftCut: start - offset,
|
leftCut: start - offset,
|
||||||
rightCut: (end % chunkSize) + 1,
|
rightCut: (end % chunkSize) + 1,
|
||||||
totalParts: int((end - offset + chunkSize) / chunkSize),
|
totalParts: int((end - offset + chunkSize) / chunkSize),
|
||||||
offset: offset,
|
offset: offset,
|
||||||
|
limit: end - start + 1,
|
||||||
chunkSize: chunkSize,
|
chunkSize: chunkSize,
|
||||||
chunkSrc: chunkSrc,
|
chunkSrc: chunkSrc,
|
||||||
timeout: config.Stream.ChunkTimeout,
|
currentPart: 1,
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
err: make(chan error, 1),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.concurrency == 0 {
|
|
||||||
r.currentPart = 1
|
|
||||||
go r.fillBufferSequentially()
|
|
||||||
} else {
|
|
||||||
go r.fillBufferConcurrently()
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tgReader) Close() error {
|
|
||||||
close(r.done)
|
|
||||||
close(r.bufferChan)
|
|
||||||
r.closed = true
|
|
||||||
for b := range r.bufferChan {
|
|
||||||
if b != nil {
|
|
||||||
b = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if r.cur != nil {
|
|
||||||
r.cur = nil
|
|
||||||
}
|
|
||||||
close(r.err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *tgReader) Read(p []byte) (int, error) {
|
func (r *tgReader) Read(p []byte) (int, error) {
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.limit <= 0 {
|
if r.limit <= 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.cur.isEmpty() {
|
if r.cur.isEmpty() {
|
||||||
if r.cur != nil {
|
r.cur, r.err = r.next()
|
||||||
r.cur = nil
|
if r.err != nil {
|
||||||
}
|
return 0, r.err
|
||||||
select {
|
|
||||||
case cur, ok := <-r.bufferChan:
|
|
||||||
if !ok && r.limit > 0 {
|
|
||||||
return 0, ErrorStreamAbandoned
|
|
||||||
}
|
|
||||||
r.cur = cur
|
|
||||||
|
|
||||||
case err := <-r.err:
|
|
||||||
return 0, fmt.Errorf("error reading chunk: %w", err)
|
|
||||||
case <-r.ctx.Done():
|
|
||||||
return 0, r.ctx.Err()
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,157 +69,29 @@ func (r *tgReader) Read(p []byte) (int, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tgReader) fillBufferConcurrently() error {
|
func (*tgReader) Close() error {
|
||||||
|
return nil
|
||||||
var mapMu sync.Mutex
|
|
||||||
|
|
||||||
bufferMap := make(map[int]*buffer)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
|
|
||||||
for i := range bufferMap {
|
|
||||||
delete(bufferMap, i)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
cb := func(ctx context.Context, i int) func() error {
|
|
||||||
return func() error {
|
|
||||||
|
|
||||||
chunk, err := r.chunkSrc.Chunk(ctx, r.offset+(int64(i)*r.chunkSize), r.chunkSize)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if r.totalParts == 1 {
|
|
||||||
chunk = chunk[r.leftCut:r.rightCut]
|
|
||||||
} else if r.currentPart+i+1 == 1 {
|
|
||||||
chunk = chunk[r.leftCut:]
|
|
||||||
} else if r.currentPart+i+1 == r.totalParts {
|
|
||||||
chunk = chunk[:r.rightCut]
|
|
||||||
}
|
|
||||||
buf := &buffer{buf: chunk}
|
|
||||||
mapMu.Lock()
|
|
||||||
bufferMap[i] = buf
|
|
||||||
mapMu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
|
|
||||||
g := errgroup.Group{}
|
|
||||||
|
|
||||||
g.SetLimit(r.concurrency)
|
|
||||||
|
|
||||||
for i := range r.concurrency {
|
|
||||||
if r.currentPart+i+1 <= r.totalParts {
|
|
||||||
g.Go(cb(r.ctx, i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
done <- g.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
if err != nil {
|
|
||||||
r.err <- err
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
for i := range r.concurrency {
|
|
||||||
if r.currentPart+i+1 <= r.totalParts {
|
|
||||||
select {
|
|
||||||
case <-r.done:
|
|
||||||
return nil
|
|
||||||
case r.bufferChan <- bufferMap[i]:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.currentPart += r.concurrency
|
|
||||||
r.offset += r.chunkSize * int64(r.concurrency)
|
|
||||||
for i := range bufferMap {
|
|
||||||
delete(bufferMap, i)
|
|
||||||
}
|
|
||||||
if r.currentPart >= r.totalParts {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-time.After(r.timeout):
|
|
||||||
return nil
|
|
||||||
case <-r.done:
|
|
||||||
return nil
|
|
||||||
case <-r.ctx.Done():
|
|
||||||
return r.ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tgReader) fillBufferSequentially() error {
|
func (r *tgReader) next() (*buffer, error) {
|
||||||
|
|
||||||
fetchChunk := func(ctx context.Context) (*buffer, error) {
|
if r.currentPart > r.totalParts {
|
||||||
chunk, err := r.chunkSrc.Chunk(ctx, r.offset, r.chunkSize)
|
return nil, io.EOF
|
||||||
if err != nil {
|
}
|
||||||
return nil, err
|
chunk, err := r.chunkSrc.Chunk(r.ctx, r.offset, r.chunkSize)
|
||||||
}
|
if err != nil {
|
||||||
if r.totalParts == 1 {
|
return nil, err
|
||||||
chunk = chunk[r.leftCut:r.rightCut]
|
}
|
||||||
} else if r.currentPart == 1 {
|
if r.totalParts == 1 {
|
||||||
chunk = chunk[r.leftCut:]
|
chunk = chunk[r.leftCut:r.rightCut]
|
||||||
} else if r.currentPart == r.totalParts {
|
} else if r.currentPart == 1 {
|
||||||
chunk = chunk[:r.rightCut]
|
chunk = chunk[r.leftCut:]
|
||||||
}
|
} else if r.currentPart == r.totalParts {
|
||||||
return &buffer{buf: chunk}, nil
|
chunk = chunk[:r.rightCut]
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
r.currentPart++
|
||||||
select {
|
r.offset += r.chunkSize
|
||||||
case <-r.done:
|
return &buffer{buf: chunk}, nil
|
||||||
return nil
|
|
||||||
case <-r.ctx.Done():
|
|
||||||
return r.ctx.Err()
|
|
||||||
case <-time.After(r.timeout):
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
buf, err := fetchChunk(r.ctx)
|
|
||||||
if err != nil {
|
|
||||||
r.err <- err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if r.closed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
r.bufferChan <- buf
|
|
||||||
r.currentPart++
|
|
||||||
r.offset += r.chunkSize
|
|
||||||
if r.currentPart > r.totalParts {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type buffer struct {
|
|
||||||
buf []byte
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) isEmpty() bool {
|
|
||||||
if b == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if len(b.buf)-b.offset <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) buffer() []byte {
|
|
||||||
return b.buf[b.offset:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *buffer) increment(n int) {
|
|
||||||
b.offset += n
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue