teldrive/internal/reader/reader.go
2024-08-31 15:50:17 +00:00

148 lines
3 KiB
Go

package reader
import (
"context"
"fmt"
"io"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/pkg/schemas"
"github.com/tgdrive/teldrive/pkg/types"
)
// Range describes the portion of a single part that has to be read.
// Start and End are inclusive byte offsets measured from the beginning of
// the part, and PartNo is the zero-based index of that part in the file's
// part list.
type Range struct {
	Start  int64
	End    int64
	PartNo int64
}
// LinearReader streams an inclusive byte window of a file by reading the
// file's parts one after another; at most one part reader is open at once.
type LinearReader struct {
	// ctx is kept here because io.Reader.Read takes no context; it is handed
	// to every part reader this LinearReader opens.
	ctx context.Context

	// What is being read.
	file  *schemas.FileOutFull
	parts []types.Part

	// Iteration state: ranges is the per-part read plan, pos indexes the
	// range currently being read, reader is the open stream for that range,
	// and remaining counts window bytes not yet returned to the caller.
	ranges    []Range
	pos       int
	reader    io.ReadCloser
	remaining int64

	// Collaborators used to open part readers.
	config      *config.TGConfig
	worker      *tgc.StreamWorker
	client      *tg.Client
	concurrency int
	cache       cache.Cacher
}
// calculatePartByteRanges maps the absolute, inclusive byte window
// [start, end] of a file stored as parts of partSize bytes onto one Range
// per part the window touches. Each Range carries offsets relative to the
// start of its part (inclusive on both ends) and the zero-based part index.
//
// A non-positive partSize yields no ranges rather than panicking with an
// integer division by zero.
func calculatePartByteRanges(start, end, partSize int64) []Range {
	if partSize <= 0 {
		return nil
	}
	startPart := start / partSize
	endPart := end / partSize
	// The number of parts is known up front; clamp at 0 so an inverted
	// window (startPart > endPart) does not request a negative capacity.
	ranges := make([]Range, 0, max(endPart-startPart+1, 0))
	for part := startPart; part <= endPart; part++ {
		partOffset := part * partSize
		ranges = append(ranges, Range{
			Start:  max(start-partOffset, 0),
			End:    min(partSize-1, end-partOffset),
			PartNo: part,
		})
	}
	return ranges
}
// NewLinearReader returns an io.ReadCloser over the inclusive byte window
// [start, end] of file, reading the file's parts sequentially.
//
// The size of parts[0] is used as the size of every part when mapping the
// window onto parts, so parts is presumably ordered by part number with all
// parts except possibly the last sharing that size (NOTE(review): confirm
// with callers). concurrency is forwarded to each part reader; values below
// 2 select a single-stream reader.
//
// Instead of panicking later, an error is returned when parts is empty, the
// part size is not positive, or the window maps to part indices that do not
// exist in parts. The reader for the first part is opened eagerly and any
// error from opening it is returned.
func NewLinearReader(ctx context.Context,
	client *tg.Client,
	worker *tgc.StreamWorker,
	cache cache.Cacher,
	file *schemas.FileOutFull,
	parts []types.Part,
	start,
	end int64,
	config *config.TGConfig,
	concurrency int,
) (io.ReadCloser, error) {
	if len(parts) == 0 {
		return nil, fmt.Errorf("reader: file has no parts")
	}
	partSize := parts[0].Size
	if partSize <= 0 {
		return nil, fmt.Errorf("reader: invalid part size %d", partSize)
	}
	ranges := calculatePartByteRanges(start, end, partSize)
	if len(ranges) == 0 {
		return nil, fmt.Errorf("reader: byte range %d-%d selects no parts", start, end)
	}
	// Ranges cover consecutive part numbers, so checking both ends is enough
	// to guarantee every r.parts[PartNo] lookup stays in bounds.
	firstPart, lastPart := ranges[0].PartNo, ranges[len(ranges)-1].PartNo
	if firstPart < 0 || lastPart >= int64(len(parts)) {
		return nil, fmt.Errorf("reader: byte range %d-%d needs parts %d-%d but file has %d parts",
			start, end, firstPart, lastPart, len(parts))
	}

	r := &LinearReader{
		ctx:         ctx,
		parts:       parts,
		file:        file,
		remaining:   end - start + 1,
		ranges:      ranges,
		config:      config,
		client:      client,
		worker:      worker,
		concurrency: concurrency,
		cache:       cache,
	}
	if err := r.initializeReader(); err != nil {
		return nil, err
	}
	return r, nil
}
// Read implements io.Reader. It reads from the current part and, when that
// part is exhausted while bytes of the window remain, advances to the next
// part. It returns io.EOF once the whole window has been delivered and
// io.ErrUnexpectedEOF if the parts run out before the window is complete.
func (r *LinearReader) Read(p []byte) (int, error) {
	if r.remaining <= 0 {
		return 0, io.EOF
	}
	if r.reader == nil {
		// Close was called, or a previous Read failed to advance to the next part.
		return 0, fmt.Errorf("reader: no active part reader")
	}
	n, err := r.reader.Read(p)
	r.remaining -= int64(n)
	if err == io.EOF && r.remaining > 0 {
		if nextErr := r.moveToNextPart(); nextErr != nil {
			if nextErr == io.EOF {
				// Every range was consumed, yet the window is incomplete: the
				// parts delivered fewer bytes than requested.
				nextErr = io.ErrUnexpectedEOF
			}
			return n, nextErr
		}
		err = nil
	}
	return n, err
}

// Close releases the active part reader, if any. Calling it more than once
// is safe; calls after the first return nil.
func (r *LinearReader) Close() error {
	if r.reader == nil {
		return nil
	}
	err := r.reader.Close()
	r.reader = nil
	return err
}

// initializeReader opens the reader for the range at r.pos and installs it
// as the active reader. On error the active reader is left unchanged.
func (r *LinearReader) initializeReader() error {
	reader, err := r.getPartReader()
	if err != nil {
		return err
	}
	r.reader = reader
	return nil
}

// moveToNextPart closes the current part reader and opens the next range.
// It returns io.EOF when no ranges remain. The previous reader is always
// detached after being closed, so a later Close never closes it twice.
func (r *LinearReader) moveToNextPart() error {
	if r.reader != nil {
		// The finished part has been read to EOF; a failure to release it is
		// not actionable mid-stream, so it is deliberately ignored.
		_ = r.reader.Close()
		r.reader = nil
	}
	r.pos++
	if r.pos >= len(r.ranges) {
		return io.EOF
	}
	return r.initializeReader()
}
// getPartReader opens a stream over the byte window described by the range
// at r.pos. It uses a multi-stream reader when concurrency is 2 or more and
// a single-stream reader otherwise.
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
	rng := r.ranges[r.pos]
	id := r.parts[rng.PartNo].ID
	src := &chunkSource{
		channelID:   r.file.ChannelID,
		partID:      id,
		client:      r.client,
		concurrency: r.concurrency,
		cache:       r.cache,
		// Presumably the cache key for this part's location; verify in chunkSource.
		key:    fmt.Sprintf("files:location:%s:%d", r.file.Id, id),
		worker: r.worker,
	}
	if r.concurrency >= 2 {
		return newTGMultiReader(r.ctx, rng.Start, rng.End, r.config, src)
	}
	return newTGReader(r.ctx, rng.Start, rng.End, src)
}