mirror of
https://github.com/tgdrive/teldrive.git
synced 2024-11-10 09:02:52 +08:00
refactor: readers
This commit is contained in:
parent
a40304eeed
commit
619fc3eccd
3 changed files with 110 additions and 109 deletions
|
@ -14,14 +14,14 @@ import (
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
)
|
)
|
||||||
|
|
||||||
type decrpytedReader struct {
|
type DecrpytedReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
file *schemas.FileOutFull
|
file *schemas.FileOutFull
|
||||||
parts []types.Part
|
parts []types.Part
|
||||||
ranges []types.Range
|
ranges []Range
|
||||||
pos int
|
pos int
|
||||||
reader io.ReadCloser
|
reader io.ReadCloser
|
||||||
limit int64
|
remaining int64
|
||||||
config *config.TGConfig
|
config *config.TGConfig
|
||||||
worker *tgc.StreamWorker
|
worker *tgc.StreamWorker
|
||||||
client *tg.Client
|
client *tg.Client
|
||||||
|
@ -40,13 +40,13 @@ func NewDecryptedReader(
|
||||||
end int64,
|
end int64,
|
||||||
config *config.TGConfig,
|
config *config.TGConfig,
|
||||||
concurrency int,
|
concurrency int,
|
||||||
) (*decrpytedReader, error) {
|
) (*DecrpytedReader, error) {
|
||||||
|
|
||||||
r := &decrpytedReader{
|
r := &DecrpytedReader{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
parts: parts,
|
parts: parts,
|
||||||
file: file,
|
file: file,
|
||||||
limit: end - start + 1,
|
remaining: end - start + 1,
|
||||||
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
ranges: calculatePartByteRanges(start, end, parts[0].DecryptedSize),
|
||||||
config: config,
|
config: config,
|
||||||
client: client,
|
client: client,
|
||||||
|
@ -54,57 +54,73 @@ func NewDecryptedReader(
|
||||||
concurrency: concurrency,
|
concurrency: concurrency,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
}
|
}
|
||||||
res, err := r.nextPart()
|
if err := r.initializeReader(); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r.reader = res
|
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *decrpytedReader) Read(p []byte) (int, error) {
|
func (r *DecrpytedReader) Read(p []byte) (int, error) {
|
||||||
|
if r.remaining <= 0 {
|
||||||
if r.limit <= 0 {
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := r.reader.Read(p)
|
n, err := r.reader.Read(p)
|
||||||
|
r.remaining -= int64(n)
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF && r.remaining > 0 {
|
||||||
if r.limit > 0 {
|
if err := r.moveToNextPart(); err != nil {
|
||||||
err = nil
|
return n, err
|
||||||
if r.reader != nil {
|
|
||||||
r.reader.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.pos++
|
|
||||||
if r.pos < len(r.ranges) {
|
|
||||||
r.reader, err = r.nextPart()
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
err = nil
|
||||||
}
|
}
|
||||||
r.limit -= int64(n)
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
func (r *decrpytedReader) Close() (err error) {
|
|
||||||
|
func (r *DecrpytedReader) Close() error {
|
||||||
if r.reader != nil {
|
if r.reader != nil {
|
||||||
err = r.reader.Close()
|
err := r.reader.Close()
|
||||||
r.reader = nil
|
r.reader = nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
func (r *DecrpytedReader) initializeReader() error {
|
||||||
|
reader, err := r.getPartReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.reader = reader
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
start := r.ranges[r.pos].Start
|
func (r *DecrpytedReader) moveToNextPart() error {
|
||||||
end := r.ranges[r.pos].End
|
r.reader.Close()
|
||||||
|
r.pos++
|
||||||
|
if r.pos < len(r.ranges) {
|
||||||
|
return r.initializeReader()
|
||||||
|
}
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DecrpytedReader) getPartReader() (io.ReadCloser, error) {
|
||||||
|
currentRange := r.ranges[r.pos]
|
||||||
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
||||||
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
|
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
|
||||||
|
partID := r.parts[currentRange.PartNo].ID
|
||||||
|
|
||||||
|
chunkSrc := &chunkSource{
|
||||||
|
channelID: r.file.ChannelID,
|
||||||
|
partID: partID,
|
||||||
|
client: r.client,
|
||||||
|
concurrency: r.concurrency,
|
||||||
|
cache: r.cache,
|
||||||
|
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
|
||||||
|
worker: r.worker,
|
||||||
|
}
|
||||||
|
|
||||||
return cipher.DecryptDataSeek(r.ctx,
|
return cipher.DecryptDataSeek(r.ctx,
|
||||||
func(ctx context.Context,
|
func(ctx context.Context,
|
||||||
|
@ -115,22 +131,11 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
|
||||||
if underlyingLimit >= 0 {
|
if underlyingLimit >= 0 {
|
||||||
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
end = min(r.parts[r.ranges[r.pos].PartNo].Size-1, underlyingOffset+underlyingLimit-1)
|
||||||
}
|
}
|
||||||
partID := r.parts[r.ranges[r.pos].PartNo].ID
|
|
||||||
|
|
||||||
chunkSrc := &chunkSource{
|
|
||||||
channelID: r.file.ChannelID,
|
|
||||||
partID: partID,
|
|
||||||
client: r.client,
|
|
||||||
concurrency: r.concurrency,
|
|
||||||
cache: r.cache,
|
|
||||||
key: fmt.Sprintf("files:location:%s:%d", r.file.Id, partID),
|
|
||||||
worker: r.worker,
|
|
||||||
}
|
|
||||||
if r.concurrency < 2 {
|
if r.concurrency < 2 {
|
||||||
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
|
return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
|
||||||
}
|
}
|
||||||
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
|
||||||
|
|
||||||
}, start, end-start+1)
|
}, currentRange.Start, currentRange.End-currentRange.Start+1)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,43 +13,19 @@ import (
|
||||||
"github.com/gotd/td/tg"
|
"github.com/gotd/td/tg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func calculatePartByteRanges(startByte, endByte, partSize int64) []types.Range {
|
type Range struct {
|
||||||
partByteRanges := []types.Range{}
|
Start, End int64
|
||||||
startPart := startByte / partSize
|
PartNo int64
|
||||||
endPart := endByte / partSize
|
|
||||||
startOffset := startByte % partSize
|
|
||||||
|
|
||||||
for part := startPart; part <= endPart; part++ {
|
|
||||||
partStartByte := int64(0)
|
|
||||||
partEndByte := partSize - 1
|
|
||||||
|
|
||||||
if part == startPart {
|
|
||||||
partStartByte = startOffset
|
|
||||||
}
|
|
||||||
if part == endPart {
|
|
||||||
partEndByte = endByte % partSize
|
|
||||||
}
|
|
||||||
|
|
||||||
partByteRanges = append(partByteRanges, types.Range{
|
|
||||||
Start: partStartByte,
|
|
||||||
End: partEndByte,
|
|
||||||
PartNo: part,
|
|
||||||
})
|
|
||||||
|
|
||||||
startOffset = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return partByteRanges
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type LinearReader struct {
|
type LinearReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
file *schemas.FileOutFull
|
file *schemas.FileOutFull
|
||||||
parts []types.Part
|
parts []types.Part
|
||||||
ranges []types.Range
|
ranges []Range
|
||||||
pos int
|
pos int
|
||||||
reader io.ReadCloser
|
reader io.ReadCloser
|
||||||
limit int64
|
remaining int64
|
||||||
config *config.TGConfig
|
config *config.TGConfig
|
||||||
worker *tgc.StreamWorker
|
worker *tgc.StreamWorker
|
||||||
client *tg.Client
|
client *tg.Client
|
||||||
|
@ -57,6 +33,23 @@ type LinearReader struct {
|
||||||
cache cache.Cacher
|
cache cache.Cacher
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func calculatePartByteRanges(start, end, partSize int64) []Range {
|
||||||
|
ranges := make([]Range, 0)
|
||||||
|
startPart := start / partSize
|
||||||
|
endPart := end / partSize
|
||||||
|
|
||||||
|
for part := startPart; part <= endPart; part++ {
|
||||||
|
partStart := max(start-part*partSize, 0)
|
||||||
|
partEnd := min(partSize-1, end-part*partSize)
|
||||||
|
ranges = append(ranges, Range{
|
||||||
|
Start: partStart,
|
||||||
|
End: partEnd,
|
||||||
|
PartNo: part,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return ranges
|
||||||
|
}
|
||||||
|
|
||||||
func NewLinearReader(ctx context.Context,
|
func NewLinearReader(ctx context.Context,
|
||||||
client *tg.Client,
|
client *tg.Client,
|
||||||
worker *tgc.StreamWorker,
|
worker *tgc.StreamWorker,
|
||||||
|
@ -73,7 +66,7 @@ func NewLinearReader(ctx context.Context,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
parts: parts,
|
parts: parts,
|
||||||
file: file,
|
file: file,
|
||||||
limit: end - start + 1,
|
remaining: end - start + 1,
|
||||||
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
ranges: calculatePartByteRanges(start, end, parts[0].Size),
|
||||||
config: config,
|
config: config,
|
||||||
client: client,
|
client: client,
|
||||||
|
@ -82,42 +75,60 @@ func NewLinearReader(ctx context.Context,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
if err := r.initializeReader(); err != nil {
|
||||||
r.reader, err = r.nextPart()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *LinearReader) Read(p []byte) (int, error) {
|
func (r *LinearReader) Read(p []byte) (int, error) {
|
||||||
if r.limit <= 0 {
|
if r.remaining <= 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := r.reader.Read(p)
|
n, err := r.reader.Read(p)
|
||||||
|
r.remaining -= int64(n)
|
||||||
|
|
||||||
if err == io.EOF && r.limit > 0 {
|
if err == io.EOF && r.remaining > 0 {
|
||||||
|
if err := r.moveToNextPart(); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
err = nil
|
err = nil
|
||||||
if r.reader != nil {
|
|
||||||
r.reader.Close()
|
|
||||||
}
|
|
||||||
r.pos++
|
|
||||||
if r.pos < len(r.ranges) {
|
|
||||||
r.reader, err = r.nextPart()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.limit -= int64(n)
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *LinearReader) nextPart() (io.ReadCloser, error) {
|
func (r *LinearReader) Close() error {
|
||||||
start := r.ranges[r.pos].Start
|
if r.reader != nil {
|
||||||
end := r.ranges[r.pos].End
|
err := r.reader.Close()
|
||||||
|
r.reader = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
partID := r.parts[r.ranges[r.pos].PartNo].ID
|
func (r *LinearReader) initializeReader() error {
|
||||||
|
reader, err := r.getPartReader()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.reader = reader
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LinearReader) moveToNextPart() error {
|
||||||
|
r.reader.Close()
|
||||||
|
r.pos++
|
||||||
|
if r.pos < len(r.ranges) {
|
||||||
|
return r.initializeReader()
|
||||||
|
}
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
|
||||||
|
currentRange := r.ranges[r.pos]
|
||||||
|
partID := r.parts[currentRange.PartNo].ID
|
||||||
|
|
||||||
chunkSrc := &chunkSource{
|
chunkSrc := &chunkSource{
|
||||||
channelID: r.file.ChannelID,
|
channelID: r.file.ChannelID,
|
||||||
|
@ -130,16 +141,7 @@ func (r *LinearReader) nextPart() (io.ReadCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.concurrency < 2 {
|
if r.concurrency < 2 {
|
||||||
return newTGReader(r.ctx, start, end, chunkSrc)
|
return newTGReader(r.ctx, currentRange.Start, currentRange.End, chunkSrc)
|
||||||
}
|
}
|
||||||
return newTGMultiReader(r.ctx, start, end, r.config, chunkSrc)
|
return newTGMultiReader(r.ctx, currentRange.Start, currentRange.End, r.config, chunkSrc)
|
||||||
}
|
|
||||||
|
|
||||||
func (r *LinearReader) Close() error {
|
|
||||||
if r.reader != nil {
|
|
||||||
err := r.reader.Close()
|
|
||||||
r.reader = nil
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,9 +46,3 @@ type BotInfo struct {
|
||||||
AccessHash int64
|
AccessHash int64
|
||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Range struct {
|
|
||||||
Start int64
|
|
||||||
End int64
|
|
||||||
PartNo int64
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue