mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-02-24 06:55:02 +08:00
refactor: readers
This commit is contained in:
parent
a40304eeed
commit
619fc3eccd
3 changed files with 110 additions and 109 deletions
|
@ -14,14 +14,14 @@ import (
|
|||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
type decrpytedReader struct {
|
||||
type DecrpytedReader struct {
|
||||
ctx context.Context
|
||||
file *schemas.FileOutFull
|
||||
parts []types.Part
|
||||
ranges []types.Range
|
||||
ranges []Range
|
||||
pos int
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
remaining int64
|
||||
config *config.TGConfig
|
||||
worker *tgc.StreamWorker
|
||||
client *tg.Client
|
||||
|
@ -40,13 +40,13 @@ func NewDecryptedReader(
|
|||
end int64,
|
||||
config *config.TGConfig,
|
||||
concurrency int,
|
||||
) (*decrpytedReader, error) {
|
||||
) (*DecrpytedReader, error) {
|
||||
|
||||
r := &decrpytedReader{
|
||||
r := &DecrpytedReader{
|
||||
ctx: ctx,
|
||||
parts: parts,
|
||||
file: file,
|
||||
limit: end - start + 1,
|
||||
remaining: end - start + 1,
|
||||
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
||||
config: config,
|
||||
client: client,
|
||||
|
@ -54,68 +54,63 @@ func NewDecryptedReader(
|
|||
concurrency: concurrency,
|
||||
cache: cache,
|
||||
}
|
||||
res, err := r.nextPart()
|
||||
|
||||
if err != nil {
|
||||
if err := r.initializeReader(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.reader = res
|
||||
|
||||
return r, nil
|
||||
|
||||
}
|
||||
|
||||
func (r *decrpytedReader) Read(p []byte) (int, error) {
|
||||
|
||||
if r.limit <= 0 {
|
||||
func (r *DecrpytedReader) Read(p []byte) (int, error) {
|
||||
if r.remaining <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n, err := r.reader.Read(p)
|
||||
r.remaining -= int64(n)
|
||||
|
||||
if err == io.EOF {
|
||||
if r.limit > 0 {
|
||||
if err == io.EOF && r.remaining > 0 {
|
||||
if err := r.moveToNextPart(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
err = nil
|
||||
if r.reader != nil {
|
||||
r.reader.Close()
|
||||
}
|
||||
}
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
r.reader, err = r.nextPart()
|
||||
|
||||
}
|
||||
}
|
||||
r.limit -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
func (r *decrpytedReader) Close() (err error) {
|
||||
|
||||
func (r *DecrpytedReader) Close() error {
|
||||
if r.reader != nil {
|
||||
err = r.reader.Close()
|
||||
err := r.reader.Close()
|
||||
r.reader = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||
func (r *DecrpytedReader) initializeReader() error {
|
||||
reader, err := r.getPartReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.reader = reader
|
||||
return nil
|
||||
}
|
||||
|
||||
start := r.ranges[r.pos].Start
|
||||
end := r.ranges[r.pos].End
|
||||
func (r *DecrpytedReader) moveToNextPart() error {
|
||||
r.reader.Close()
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
return r.initializeReader()
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (r *DecrpytedReader) getPartReader() (io.ReadCloser, error) {
|
||||
currentRange := r.ranges[r.pos]
|
||||
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
||||
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
|
||||
|
||||
return cipher.DecryptDataSeek(r.ctx,
|
||||
func(ctx context.Context,
|
||||
underlyingOffset,
|
||||
underlyingLimit int64) (io.ReadCloser, error) {
|
||||
var end int64
|
||||
|
||||
if underlyingLimit >= 0 {
|
||||
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
||||
}
|
||||
partID := r.parts[r.ranges[r.pos].PartNo].ID
|
||||
partID := r.parts[currentRange.PartNo].ID
|
||||
|
||||
chunkSrc := &chunkSource{
|
||||
channelID: r.file.ChannelID,
|
||||
|
@ -126,11 +121,21 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
|||
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
|
||||
worker: r.worker,
|
||||
}
|
||||
|
||||
return cipher.DecryptDataSeek(r.ctx,
|
||||
func(ctx context.Context,
|
||||
underlyingOffset,
|
||||
underlyingLimit int64) (io.ReadCloser, error) {
|
||||
var end int64
|
||||
|
||||
if underlyingLimit >= 0 {
|
||||
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
||||
}
|
||||
|
||||
if r.concurrency < 2 {
|
||||
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
|
||||
}
|
||||
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
||||
|
||||
}, start, end-start+1)
|
||||
|
||||
}, currentRange.Start, currentRange.End-currentRange.Start+1)
|
||||
}
|
||||
|
|
|
@ -13,43 +13,19 @@ import (
|
|||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
||||
partByteRanges := []types.Range{}
|
||||
startPart := startByte / partSize
|
||||
endPart := endByte / partSize
|
||||
startOffset := startByte % partSize
|
||||
|
||||
for part := startPart; part <= endPart; part++ {
|
||||
partStartByte := int64(0)
|
||||
partEndByte := partSize - 1
|
||||
|
||||
if part == startPart {
|
||||
partStartByte = startOffset
|
||||
}
|
||||
if part == endPart {
|
||||
partEndByte = endByte % partSize
|
||||
}
|
||||
|
||||
partByteRanges = append(partByteRanges, types.Range{
|
||||
Start: partStartByte,
|
||||
End: partEndByte,
|
||||
PartNo: part,
|
||||
})
|
||||
|
||||
startOffset = 0
|
||||
}
|
||||
|
||||
return partByteRanges
|
||||
type Range struct {
|
||||
Start, End int64
|
||||
PartNo int64
|
||||
}
|
||||
|
||||
type LinearReader struct {
|
||||
ctx context.Context
|
||||
file *schemas.FileOutFull
|
||||
parts []types.Part
|
||||
ranges []types.Range
|
||||
ranges []Range
|
||||
pos int
|
||||
reader io.ReadCloser
|
||||
limit int64
|
||||
remaining int64
|
||||
config *config.TGConfig
|
||||
worker *tgc.StreamWorker
|
||||
client *tg.Client
|
||||
|
@ -57,6 +33,23 @@ type LinearReader struct {
|
|||
cache cache.Cacher
|
||||
}
|
||||
|
||||
func calculatePartByteRanges(start, end, partSize int64) []Range {
|
||||
ranges := make([]Range, 0)
|
||||
startPart := start / partSize
|
||||
endPart := end / partSize
|
||||
|
||||
for part := startPart; part <= endPart; part++ {
|
||||
partStart := max(start-part*partSize, 0)
|
||||
partEnd := min(partSize-1, end-part*partSize)
|
||||
ranges = append(ranges, Range{
|
||||
Start: partStart,
|
||||
End: partEnd,
|
||||
PartNo: part,
|
||||
})
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
|
||||
func NewLinearReader(ctx context.Context,
|
||||
client *tg.Client,
|
||||
worker *tgc.StreamWorker,
|
||||
|
@ -73,7 +66,7 @@ func NewLinearReader(ctx context.Context,
|
|||
ctx: ctx,
|
||||
parts: parts,
|
||||
file: file,
|
||||
limit: end - start + 1,
|
||||
remaining: end - start + 1,
|
||||
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
||||
config: config,
|
||||
client: client,
|
||||
|
@ -82,42 +75,60 @@ func NewLinearReader(ctx context.Context,
|
|||
cache: cache,
|
||||
}
|
||||
|
||||
var err error
|
||||
r.reader, err = r.nextPart()
|
||||
if err != nil {
|
||||
if err := r.initializeReader(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *LinearReader) Read(p []byte) (int, error) {
|
||||
if r.limit <= 0 {
|
||||
if r.remaining <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n, err := r.reader.Read(p)
|
||||
r.remaining -= int64(n)
|
||||
|
||||
if err == io.EOF && r.limit > 0 {
|
||||
if err == io.EOF && r.remaining > 0 {
|
||||
if err := r.moveToNextPart(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
err = nil
|
||||
if r.reader != nil {
|
||||
r.reader.Close()
|
||||
}
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
r.reader, err = r.nextPart()
|
||||
}
|
||||
}
|
||||
|
||||
r.limit -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *LinearReader) nextPart() (io.ReadCloser, error) {
|
||||
start := r.ranges[r.pos].Start
|
||||
end := r.ranges[r.pos].End
|
||||
func (r *LinearReader) Close() error {
|
||||
if r.reader != nil {
|
||||
err := r.reader.Close()
|
||||
r.reader = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
partID := r.parts[r.ranges[r.pos].PartNo].ID
|
||||
func (r *LinearReader) initializeReader() error {
|
||||
reader, err := r.getPartReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.reader = reader
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LinearReader) moveToNextPart() error {
|
||||
r.reader.Close()
|
||||
r.pos++
|
||||
if r.pos < len(r.ranges) {
|
||||
return r.initializeReader()
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
|
||||
currentRange := r.ranges[r.pos]
|
||||
partID := r.parts[currentRange.PartNo].ID
|
||||
|
||||
chunkSrc := &chunkSource{
|
||||
channelID: r.file.ChannelID,
|
||||
|
@ -130,16 +141,7 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
|
|||
}
|
||||
|
||||
if r.concurrency < 2 {
|
||||
return newTGReader(r.ctx, start, end, chunkSrc)
|
||||
return newTGReader(r.ctx, currentRange.Start, currentRange.End, chunkSrc)
|
||||
}
|
||||
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
|
||||
}
|
||||
|
||||
func (r *LinearReader) Close() error {
|
||||
if r.reader != nil {
|
||||
err := r.reader.Close()
|
||||
r.reader = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return newTGMultiReader(r.ctx, currentRange.Start, currentRange.End, r.config, chunkSrc)
|
||||
}
|
||||
|
|
|
@ -46,9 +46,3 @@ type BotInfo struct {
|
|||
AccessHash int64
|
||||
Token string
|
||||
}
|
||||
|
||||
type Range struct {
|
||||
Start int64
|
||||
End int64
|
||||
PartNo int64
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue