refactor: readers

divyam234 2024-08-12 14:32:33 +05:30
parent a40304eeed
commit 619fc3eccd
3 changed files with 110 additions and 109 deletions

View file

@@ -14,14 +14,14 @@ import (
     "github.com/gotd/td/tg"
 )
 
-type decrpytedReader struct {
+type DecrpytedReader struct {
     ctx context.Context
     file *schemas.FileOutFull
     parts []types.Part
-    ranges []types.Range
+    ranges []Range
     pos int
     reader io.ReadCloser
-    limit int64
+    remaining int64
     config *config.TGConfig
     worker *tgc.StreamWorker
     client *tg.Client
@@ -40,13 +40,13 @@ func NewDecryptedReader(
     end int64,
     config *config.TGConfig,
     concurrency int,
-) (*decrpytedReader, error) {
-    r := &decrpytedReader{
+) (*DecrpytedReader, error) {
+    r := &DecrpytedReader{
         ctx: ctx,
         parts: parts,
         file: file,
-        limit: end - start + 1,
+        remaining: end - start + 1,
         ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
         config: config,
         client: client,
@@ -54,57 +54,73 @@ func NewDecryptedReader(
         concurrency: concurrency,
         cache: cache,
     }
-    res, err := r.nextPart()
-    if err != nil {
+    if err := r.initializeReader(); err != nil {
         return nil, err
     }
-    r.reader = res
     return r, nil
 }
 
-func (r *decrpytedReader) Read(p []byte) (int, error) {
-    if r.limit <= 0 {
+func (r *DecrpytedReader) Read(p []byte) (int, error) {
+    if r.remaining <= 0 {
         return 0, io.EOF
     }
     n, err := r.reader.Read(p)
-    if err == io.EOF {
-        if r.limit > 0 {
-            err = nil
-            if r.reader != nil {
-                r.reader.Close()
-            }
-        }
-        r.pos++
-        if r.pos < len(r.ranges) {
-            r.reader, err = r.nextPart()
+    r.remaining -= int64(n)
+    if err == io.EOF && r.remaining > 0 {
+        if err := r.moveToNextPart(); err != nil {
+            return n, err
         }
+        err = nil
     }
-    r.limit -= int64(n)
     return n, err
 }
 
-func (r *decrpytedReader) Close() (err error) {
+func (r *DecrpytedReader) Close() error {
     if r.reader != nil {
-        err = r.reader.Close()
+        err := r.reader.Close()
         r.reader = nil
         return err
     }
     return nil
 }
 
-func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
-    start := r.ranges[r.pos].Start
-    end := r.ranges[r.pos].End
+func (r *DecrpytedReader) initializeReader() error {
+    reader, err := r.getPartReader()
+    if err != nil {
+        return err
+    }
+    r.reader = reader
+    return nil
+}
+
+func (r *DecrpytedReader) moveToNextPart() error {
+    r.reader.Close()
+    r.pos++
+    if r.pos < len(r.ranges) {
+        return r.initializeReader()
+    }
+    return io.EOF
+}
+
+func (r *DecrpytedReader) getPartReader() (io.ReadCloser, error) {
+    currentRange := r.ranges[r.pos]
     salt := r.parts[r.ranges[r.pos].PartNo].Salt
     cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
+    partID := r.parts[currentRange.PartNo].ID
+    chunkSrc := &chunkSource{
+        channelID: r.file.ChannelID,
+        partID: partID,
+        client: r.client,
+        concurrency: r.concurrency,
+        cache: r.cache,
+        key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
+        worker: r.worker,
+    }
     return cipher.DecryptDataSeek(r.ctx,
         func(ctx context.Context,
@@ -115,22 +131,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
             if underlyingLimit >= 0 {
                 end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
             }
-            partID := r.parts[r.ranges[r.pos].PartNo].ID
-            chunkSrc := &chunkSource{
-                channelID: r.file.ChannelID,
-                partID: partID,
-                client: r.client,
-                concurrency: r.concurrency,
-                cache: r.cache,
-                key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
-                worker: r.worker,
-            }
             if r.concurrency < 2 {
                 return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
             }
             return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
-        }, start, end-start+1)
+        }, currentRange.Start, currentRange.End-currentRange.Start+1)
 }
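The refactor above splits the old nextPart() into initializeReader, moveToNextPart, and getPartReader, and tracks the unread request window in remaining instead of limit, so Read only crosses a part boundary when the caller still needs bytes. Below is a minimal, self-contained sketch of that read-across-parts pattern using in-memory parts; multiPartReader and its helpers are hypothetical names for illustration and not the project's API.

package main

import (
    "bytes"
    "fmt"
    "io"
)

// multiPartReader is an in-memory stand-in for the pattern the refactor
// moves to: Read drains one part at a time, moveToNextPart swaps in the
// next part's reader, and remaining guards against reading past the
// requested range.
type multiPartReader struct {
    parts     [][]byte
    pos       int
    reader    io.ReadCloser
    remaining int64
}

func newMultiPartReader(parts [][]byte, limit int64) (*multiPartReader, error) {
    r := &multiPartReader{parts: parts, remaining: limit}
    if err := r.initializeReader(); err != nil {
        return nil, err
    }
    return r, nil
}

func (r *multiPartReader) initializeReader() error {
    // In the real readers this builds a decrypting, Telegram-backed part reader.
    r.reader = io.NopCloser(bytes.NewReader(r.parts[r.pos]))
    return nil
}

func (r *multiPartReader) moveToNextPart() error {
    r.reader.Close()
    r.pos++
    if r.pos < len(r.parts) {
        return r.initializeReader()
    }
    return io.EOF
}

func (r *multiPartReader) Read(p []byte) (int, error) {
    if r.remaining <= 0 {
        return 0, io.EOF
    }
    n, err := r.reader.Read(p)
    r.remaining -= int64(n)
    if err == io.EOF && r.remaining > 0 {
        // Current part is exhausted but the caller still wants bytes:
        // advance to the next part and keep going.
        if err := r.moveToNextPart(); err != nil {
            return n, err
        }
        err = nil
    }
    return n, err
}

func main() {
    r, _ := newMultiPartReader([][]byte{[]byte("hello "), []byte("world")}, 11)
    out, _ := io.ReadAll(r)
    fmt.Printf("%q\n", out) // "hello world"
}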

View file

@@ -13,43 +13,19 @@ import (
     "github.com/gotd/td/tg"
 )
 
-func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
-    partByteRanges := []types.Range{}
-    startPart := startByte / partSize
-    endPart := endByte / partSize
-    startOffset := startByte % partSize
-    for part := startPart; part <= endPart; part++ {
-        partStartByte := int64(0)
-        partEndByte := partSize - 1
-        if part == startPart {
-            partStartByte = startOffset
-        }
-        if part == endPart {
-            partEndByte = endByte % partSize
-        }
-        partByteRanges = append(partByteRanges, types.Range{
-            Start: partStartByte,
-            End: partEndByte,
-            PartNo: part,
-        })
-        startOffset = 0
-    }
-    return partByteRanges
+type Range struct {
+    Start, End int64
+    PartNo int64
 }
 
 type LinearReader struct {
     ctx context.Context
     file *schemas.FileOutFull
     parts []types.Part
-    ranges []types.Range
+    ranges []Range
     pos int
     reader io.ReadCloser
-    limit int64
+    remaining int64
     config *config.TGConfig
     worker *tgc.StreamWorker
     client *tg.Client
@@ -57,6 +33,23 @@ type LinearReader struct {
     cache cache.Cacher
 }
 
+func calculatePartByteRanges(start, end, partSize int64) []Range {
+    ranges := make([]Range, 0)
+    startPart := start / partSize
+    endPart := end / partSize
+    for part := startPart; part <= endPart; part++ {
+        partStart := max(start-part*partSize, 0)
+        partEnd := min(partSize-1, end-part*partSize)
+        ranges = append(ranges, Range{
+            Start: partStart,
+            End: partEnd,
+            PartNo: part,
+        })
+    }
+    return ranges
+}
+
 func NewLinearReader(ctx context.Context,
     client *tg.Client,
     worker *tgc.StreamWorker,
@@ -73,7 +66,7 @@ func NewLinearReader(ctx context.Context,
         ctx: ctx,
         parts: parts,
         file: file,
-        limit: end - start + 1,
+        remaining: end - start + 1,
         ranges: calculatePartByteRanges(start, end, parts[0].Size),
         config: config,
         client: client,
@@ -82,42 +75,60 @@ func NewLinearReader(ctx context.Context,
         cache: cache,
     }
-    var err error
-    r.reader, err = r.nextPart()
-    if err != nil {
+    if err := r.initializeReader(); err != nil {
         return nil, err
     }
     return r, nil
 }
 
 func (r *LinearReader) Read(p []byte) (int, error) {
-    if r.limit <= 0 {
+    if r.remaining <= 0 {
         return 0, io.EOF
     }
     n, err := r.reader.Read(p)
-    if err == io.EOF && r.limit > 0 {
+    r.remaining -= int64(n)
+    if err == io.EOF && r.remaining > 0 {
+        if err := r.moveToNextPart(); err != nil {
+            return n, err
+        }
         err = nil
-        if r.reader != nil {
-            r.reader.Close()
-        }
-        r.pos++
-        if r.pos < len(r.ranges) {
-            r.reader, err = r.nextPart()
-        }
     }
-    r.limit -= int64(n)
     return n, err
 }
 
-func (r *LinearReader) nextPart() (io.ReadCloser, error) {
-    start := r.ranges[r.pos].Start
-    end := r.ranges[r.pos].End
-    partID := r.parts[r.ranges[r.pos].PartNo].ID
+func (r *LinearReader) Close() error {
+    if r.reader != nil {
+        err := r.reader.Close()
+        r.reader = nil
+        return err
+    }
+    return nil
+}
+
+func (r *LinearReader) initializeReader() error {
+    reader, err := r.getPartReader()
+    if err != nil {
+        return err
+    }
+    r.reader = reader
+    return nil
+}
+
+func (r *LinearReader) moveToNextPart() error {
+    r.reader.Close()
+    r.pos++
+    if r.pos < len(r.ranges) {
+        return r.initializeReader()
+    }
+    return io.EOF
+}
+
+func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
+    currentRange := r.ranges[r.pos]
+    partID := r.parts[currentRange.PartNo].ID
     chunkSrc := &chunkSource{
         channelID: r.file.ChannelID,
@@ -130,16 +141,7 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
     }
     if r.concurrency < 2 {
-        return newTGReader(r.ctx, start, end, chunkSrc)
+        return newTGReader(r.ctx, currentRange.Start, currentRange.End, chunkSrc)
     }
-    return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
-}
-
-func (r *LinearReader) Close() error {
-    if r.reader != nil {
-        err := r.reader.Close()
-        r.reader = nil
-        return err
-    }
-    return nil
+    return newTGMultiReader(r.ctx, currentRange.Start, currentRange.End, r.config, chunkSrc)
 }
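The rewritten calculatePartByteRanges above replaces the startOffset bookkeeping with a max/min clamp per part (it relies on the min and max builtins from Go 1.21). A small runnable sketch of how it slices a request across fixed-size parts; the helper and Range are copied from the new code in this diff, and the main function with its sample numbers is added here purely for illustration.

package main

import "fmt"

// Range mirrors the struct this commit moves into the reader code:
// a byte window inside a single stored part.
type Range struct {
    Start, End int64
    PartNo     int64
}

// calculatePartByteRanges is the new helper from the diff above: it splits an
// absolute [start, end] byte range into per-part windows by clamping the
// request against each part's boundaries.
func calculatePartByteRanges(start, end, partSize int64) []Range {
    ranges := make([]Range, 0)
    startPart := start / partSize
    endPart := end / partSize
    for part := startPart; part <= endPart; part++ {
        partStart := max(start-part*partSize, 0)
        partEnd := min(partSize-1, end-part*partSize)
        ranges = append(ranges, Range{Start: partStart, End: partEnd, PartNo: part})
    }
    return ranges
}

func main() {
    // Hypothetical request: bytes 1000-2500 of a file stored in 1024-byte parts.
    for _, r := range calculatePartByteRanges(1000, 2500, 1024) {
        fmt.Printf("part %d: bytes %d-%d\n", r.PartNo, r.Start, r.End)
    }
    // Output:
    // part 0: bytes 1000-1023
    // part 1: bytes 0-1023
    // part 2: bytes 0-452
}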

View file

@@ -46,9 +46,3 @@ type BotInfo struct {
     AccessHash int64
     Token string
 }
-
-type Range struct {
-    Start int64
-    End int64
-    PartNo int64
-}