refactor: readers

This commit is contained in:
divyam234 2024-08-12 14:32:33 +05:30
parent a40304eeed
commit 619fc3eccd
3 changed files with 110 additions and 109 deletions

View file

@ -14,14 +14,14 @@ import (
"github.com/gotd/td/tg"
)
type decrpytedReader struct {
type DecrpytedReader struct {
ctx context.Context
file *schemas.FileOutFull
parts []types.Part
ranges []types.Range
ranges []Range
pos int
reader io.ReadCloser
limit int64
remaining int64
config *config.TGConfig
worker *tgc.StreamWorker
client *tg.Client
@ -40,13 +40,13 @@ func NewDecryptedReader(
end int64,
config *config.TGConfig,
concurrency int,
) (*decrpytedReader, error) {
) (*DecrpytedReader, error) {
r := &decrpytedReader{
r := &DecrpytedReader{
ctx: ctx,
parts: parts,
file: file,
limit: end - start + 1,
remaining: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
config: config,
client: client,
@ -54,57 +54,73 @@ func NewDecryptedReader(
concurrency: concurrency,
cache: cache,
}
res, err := r.nextPart()
if err != nil {
if err := r.initializeReader(); err != nil {
return nil, err
}
r.reader = res
return r, nil
}
func (r *decrpytedReader) Read(p []byte) (int, error) {
if r.limit <= 0 {
func (r *DecrpytedReader) Read(p []byte) (int, error) {
if r.remaining <= 0 {
return 0, io.EOF
}
n, err := r.reader.Read(p)
r.remaining -= int64(n)
if err == io.EOF {
if r.limit > 0 {
err = nil
if r.reader != nil {
r.reader.Close()
}
}
r.pos++
if r.pos < len(r.ranges) {
r.reader, err = r.nextPart()
if err == io.EOF && r.remaining > 0 {
if err := r.moveToNextPart(); err != nil {
return n, err
}
err = nil
}
r.limit -= int64(n)
return n, err
}
func (r *decrpytedReader) Close() (err error) {
func (r *DecrpytedReader) Close() error {
if r.reader != nil {
err = r.reader.Close()
err := r.reader.Close()
r.reader = nil
return err
}
return nil
}
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
func (r *DecrpytedReader) initializeReader() error {
reader, err := r.getPartReader()
if err != nil {
return err
}
r.reader = reader
return nil
}
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
func (r *DecrpytedReader) moveToNextPart() error {
r.reader.Close()
r.pos++
if r.pos < len(r.ranges) {
return r.initializeReader()
}
return io.EOF
}
func (r *DecrpytedReader) getPartReader() (io.ReadCloser, error) {
currentRange := r.ranges[r.pos]
salt := r.parts[r.ranges[r.pos].PartNo].Salt
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
partID := r.parts[currentRange.PartNo].ID
chunkSrc := &chunkSource{
channelID: r.file.ChannelID,
partID: partID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
worker: r.worker,
}
return cipher.DecryptDataSeek(r.ctx,
func(ctx context.Context,
@ -115,22 +131,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
if underlyingLimit >= 0 {
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
}
partID := r.parts[r.ranges[r.pos].PartNo].ID
chunkSrc := &chunkSource{
channelID: r.file.ChannelID,
partID: partID,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
worker: r.worker,
}
if r.concurrency < 2 {
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
}
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
}, start, end-start+1)
}, currentRange.Start, currentRange.End-currentRange.Start+1)
}

View file

@ -13,43 +13,19 @@ import (
"github.com/gotd/td/tg"
)
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
partByteRanges := []types.Range{}
startPart := startByte / partSize
endPart := endByte / partSize
startOffset := startByte % partSize
for part := startPart; part <= endPart; part++ {
partStartByte := int64(0)
partEndByte := partSize - 1
if part == startPart {
partStartByte = startOffset
}
if part == endPart {
partEndByte = endByte % partSize
}
partByteRanges = append(partByteRanges, types.Range{
Start: partStartByte,
End: partEndByte,
PartNo: part,
})
startOffset = 0
}
return partByteRanges
type Range struct {
Start, End int64
PartNo int64
}
type LinearReader struct {
ctx context.Context
file *schemas.FileOutFull
parts []types.Part
ranges []types.Range
ranges []Range
pos int
reader io.ReadCloser
limit int64
remaining int64
config *config.TGConfig
worker *tgc.StreamWorker
client *tg.Client
@ -57,6 +33,23 @@ type LinearReader struct {
cache cache.Cacher
}
func calculatePartByteRanges(start, end, partSize int64) []Range {
ranges := make([]Range, 0)
startPart := start / partSize
endPart := end / partSize
for part := startPart; part <= endPart; part++ {
partStart := max(start-part*partSize, 0)
partEnd := min(partSize-1, end-part*partSize)
ranges = append(ranges, Range{
Start: partStart,
End: partEnd,
PartNo: part,
})
}
return ranges
}
func NewLinearReader(ctx context.Context,
client *tg.Client,
worker *tgc.StreamWorker,
@ -73,7 +66,7 @@ func NewLinearReader(ctx context.Context,
ctx: ctx,
parts: parts,
file: file,
limit: end - start + 1,
remaining: end - start + 1,
ranges: calculatePartByteRanges(start, end, parts[0].Size),
config: config,
client: client,
@ -82,42 +75,60 @@ func NewLinearReader(ctx context.Context,
cache: cache,
}
var err error
r.reader, err = r.nextPart()
if err != nil {
if err := r.initializeReader(); err != nil {
return nil, err
}
return r, nil
}
func (r *LinearReader) Read(p []byte) (int, error) {
if r.limit <= 0 {
if r.remaining <= 0 {
return 0, io.EOF
}
n, err := r.reader.Read(p)
r.remaining -= int64(n)
if err == io.EOF && r.limit > 0 {
if err == io.EOF && r.remaining > 0 {
if err := r.moveToNextPart(); err != nil {
return n, err
}
err = nil
if r.reader != nil {
r.reader.Close()
}
r.pos++
if r.pos < len(r.ranges) {
r.reader, err = r.nextPart()
}
}
r.limit -= int64(n)
return n, err
}
func (r *LinearReader) nextPart() (io.ReadCloser, error) {
start := r.ranges[r.pos].Start
end := r.ranges[r.pos].End
func (r *LinearReader) Close() error {
if r.reader != nil {
err := r.reader.Close()
r.reader = nil
return err
}
return nil
}
partID := r.parts[r.ranges[r.pos].PartNo].ID
func (r *LinearReader) initializeReader() error {
reader, err := r.getPartReader()
if err != nil {
return err
}
r.reader = reader
return nil
}
func (r *LinearReader) moveToNextPart() error {
r.reader.Close()
r.pos++
if r.pos < len(r.ranges) {
return r.initializeReader()
}
return io.EOF
}
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
currentRange := r.ranges[r.pos]
partID := r.parts[currentRange.PartNo].ID
chunkSrc := &chunkSource{
channelID: r.file.ChannelID,
@ -130,16 +141,7 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
}
if r.concurrency < 2 {
return newTGReader(r.ctx, start, end, chunkSrc)
return newTGReader(r.ctx, currentRange.Start, currentRange.End, chunkSrc)
}
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
}
func (r *LinearReader) Close() error {
if r.reader != nil {
err := r.reader.Close()
r.reader = nil
return err
}
return nil
return newTGMultiReader(r.ctx, currentRange.Start, currentRange.End, r.config, chunkSrc)
}

View file

@ -46,9 +46,3 @@ type BotInfo struct {
AccessHash int64
Token string
}
type Range struct {
Start int64
End int64
PartNo int64
}