teldrive/internal/reader/decrypted_reader.go
2024-08-31 15:50:17 +00:00

142 lines
3.1 KiB
Go

package reader
import (
"context"
"fmt"
"io"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/crypt"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/pkg/schemas"
"github.com/tgdrive/teldrive/pkg/types"
)
// DecrpytedReader streams the decrypted content of a byte range of a file that
// is stored as a sequence of encrypted parts. It provides Read and Close and
// walks the parts one range at a time, opening a fresh decrypting reader for
// each range.
//
// The identifier's spelling is part of the exported API.
type DecrpytedReader struct {
// ctx is passed to the cipher and to the underlying Telegram readers.
ctx context.Context
// file supplies the ChannelID and Id used to build each part's chunk source.
file *schemas.FileOutFull
// parts holds per-part metadata (ID, Salt, Size, DecryptedSize), indexed by
// Range.PartNo.
parts []types.Part
// ranges is the list of per-part byte ranges to read, computed once by
// calculatePartByteRanges in NewDecryptedReader.
ranges []Range
// pos is the index into ranges of the range currently being read.
pos int
// reader is the decrypting reader for ranges[pos]; Close sets it to nil.
reader io.ReadCloser
// remaining counts the bytes still to be returned; it starts at end-start+1.
remaining int64
// config provides Uploads.EncryptionKey and is forwarded to newTGMultiReader.
config *config.TGConfig
// worker is forwarded to the chunk source of every part.
worker *tgc.StreamWorker
// client is forwarded to the chunk source of every part.
client *tg.Client
// concurrency selects the underlying reader: values below 2 use newTGReader,
// anything else uses newTGMultiReader.
concurrency int
// cache is forwarded to the chunk source of every part.
cache cache.Cacher
}
// NewDecryptedReader builds a reader that yields the decrypted bytes
// [start, end] (both inclusive) of file, whose encrypted content is stored in
// parts.
//
// The per-part byte ranges are derived from the decrypted size of the first
// part, and the reader for the first range is opened before returning.
//
// It returns an error when parts is empty, when the requested span maps to no
// part range, or when the first part reader cannot be opened.
func NewDecryptedReader(
	ctx context.Context,
	client *tg.Client,
	worker *tgc.StreamWorker,
	cache cache.Cacher,
	file *schemas.FileOutFull,
	parts []types.Part,
	start,
	end int64,
	config *config.TGConfig,
	concurrency int,
) (*DecrpytedReader, error) {
	// parts[0] is dereferenced below; fail cleanly instead of panicking.
	if len(parts) == 0 {
		return nil, fmt.Errorf("decrypted reader: file has no parts")
	}

	ranges := calculatePartByteRanges(start, end, parts[0].DecryptedSize)
	// getPartReader indexes ranges[pos] unconditionally, so an empty result
	// must be rejected here.
	if len(ranges) == 0 {
		return nil, fmt.Errorf("decrypted reader: no part ranges for bytes %d-%d", start, end)
	}

	r := &DecrpytedReader{
		ctx:         ctx,
		parts:       parts,
		file:        file,
		remaining:   end - start + 1,
		ranges:      ranges,
		config:      config,
		client:      client,
		worker:      worker,
		concurrency: concurrency,
		cache:       cache,
	}
	if err := r.initializeReader(); err != nil {
		return nil, err
	}
	return r, nil
}
// Read fills p with decrypted bytes, moving on to the next part range whenever
// the current part reader is exhausted.
//
// It returns io.EOF once all requested bytes have been delivered, and an error
// if it is called after Close while bytes are still outstanding. When a part
// ends without yielding any bytes, Read continues with the next part instead
// of returning (0, nil), which the io.Reader contract discourages.
func (r *DecrpytedReader) Read(p []byte) (int, error) {
	for {
		if r.remaining <= 0 {
			return 0, io.EOF
		}
		// Close sets reader to nil; report that instead of panicking.
		if r.reader == nil {
			return 0, fmt.Errorf("decrypted reader: read after close")
		}

		n, err := r.reader.Read(p)
		r.remaining -= int64(n)

		// Anything other than "this part is done but more bytes are owed" is
		// passed straight to the caller.
		if err != io.EOF || r.remaining <= 0 {
			return n, err
		}
		if err := r.moveToNextPart(); err != nil {
			return n, err
		}
		// Bytes already read (or an empty buffer) are returned as-is; only a
		// bare EOF with no data triggers an immediate read from the new part.
		if n > 0 || len(p) == 0 {
			return n, nil
		}
	}
}
// Close closes the reader of the current part, if there is one, and clears it
// so that a repeated Close is a no-op. The underlying Close error is returned.
func (r *DecrpytedReader) Close() error {
	if r.reader == nil {
		return nil
	}
	closeErr := r.reader.Close()
	r.reader = nil
	return closeErr
}
// initializeReader opens the part reader for the current range and installs it
// as r.reader. On failure r.reader is left untouched and the error is returned.
func (r *DecrpytedReader) initializeReader() error {
	partReader, err := r.getPartReader()
	if err == nil {
		r.reader = partReader
	}
	return err
}
// moveToNextPart closes the current part reader, advances to the next range
// and opens its reader. It returns io.EOF when no ranges are left.
func (r *DecrpytedReader) moveToNextPart() error {
	// The close error is deliberately not checked.
	r.reader.Close()

	r.pos++
	if r.pos >= len(r.ranges) {
		return io.EOF
	}
	return r.initializeReader()
}
// getPartReader opens a decrypting reader over the range r.ranges[r.pos].
//
// A cipher is created from the configured encryption key and the part's salt,
// and is asked to decrypt currentRange.End-currentRange.Start+1 bytes starting
// at currentRange.Start. The cipher obtains the encrypted bytes through a
// callback that opens a Telegram reader over the requested underlying span.
//
// It returns an error if the cipher cannot be created or if the decrypting
// reader cannot be opened.
func (r *DecrpytedReader) getPartReader() (io.ReadCloser, error) {
	currentRange := r.ranges[r.pos]
	// Capture the part now so the callback keeps referring to this part even
	// if r.pos changes later.
	part := r.parts[currentRange.PartNo]

	cipher, err := crypt.NewCipher(r.config.Uploads.EncryptionKey, part.Salt)
	if err != nil {
		return nil, fmt.Errorf("creating cipher for part %d: %w", currentRange.PartNo, err)
	}

	chunkSrc := &chunkSource{
		channelID:   r.file.ChannelID,
		partID:      part.ID,
		client:      r.client,
		concurrency: r.concurrency,
		cache:       r.cache,
		key:         fmt.Sprintf("files:location:%s:%d", r.file.Id, part.ID),
		worker:      r.worker,
	}

	// open supplies the encrypted bytes [underlyingOffset, end] of the part.
	open := func(ctx context.Context, underlyingOffset, underlyingLimit int64) (io.ReadCloser, error) {
		// NOTE(review): a negative underlyingLimit leaves end at 0; confirm
		// that the TG readers treat that as intended.
		var end int64
		if underlyingLimit >= 0 {
			// Clamp to the last byte of the encrypted part.
			end = min(part.Size-1, underlyingOffset+underlyingLimit-1)
		}
		if r.concurrency < 2 {
			return newTGReader(r.ctx, underlyingOffset, end, chunkSrc)
		}
		return newTGMultiReader(r.ctx, underlyingOffset, end, r.config, chunkSrc)
	}

	return cipher.DecryptDataSeek(r.ctx, open, currentRange.Start, currentRange.End-currentRange.Start+1)
}