enable native file encryption

This commit is contained in:
divyam234 2023-12-08 03:16:06 +05:30
parent e685ac67d2
commit 3e30b8797c
19 changed files with 906 additions and 71 deletions

View file

@ -31,7 +31,7 @@ Telegram Drive is a powerful utility that enables you to create your own cloud s
## Features
- **UI:** Based on Material You to create nice looking UI themes.
- **Secure:** Your data is secured using Telegram's robust encryption.
- **Secure:** Your data is secured using robust encryption.
- **Flexible Deployment:** Use Docker Compose or deploy without Docker.
## Demo
@ -135,6 +135,16 @@ In addition to the mandatory variables, you can also set the following optional
- `UPLOAD_RETENTION` : No of days to keep incomplete uploads parts in channel afterwards these parts are deleted (Default 15).
- `ENCRYPTION_KEY` : Password for Encryption.
- `ENCRYPTION_SALT` : Salt for Encryption.
> [!WARNING]
> Keep your Passoword and Salt safe once generated teldrive uses same encryption as of rclone internally
so you don't need to enable crypt in rclone.Enabling crypt in rclone makes UI reduntant so encrypting files
in teldrive internally is better way to encrypt files instead of enabling in rclone.To encrypt files see more
about teldrive rclone config.
### For making use of Multi Bots support
> [!WARNING]

View file

@ -11,6 +11,8 @@ func InitRouter() *gin.Engine {
r := gin.Default()
r.Use(gin.Recovery())
r.Use(middleware.Cors())
c := controller.NewController()

View file

@ -33,6 +33,8 @@ type Config struct {
BgBotsLimit int `envconfig:"BG_BOTS_LIMIT" default:"5"`
UploadRetention int `envconfig:"UPLOAD_RETENTION" default:"15"`
DisableStreamBots bool `envconfig:"DISABLE_STREAM_BOTS" default:"false"`
EncryptionKey string `envconfig:"ENCRYPTION_KEY"`
EncryptionSalt string `envconfig:"ENCRYPTION_SALT"`
ExecDir string
}

1
go.mod
View file

@ -27,6 +27,7 @@ require (
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/google/uuid v1.4.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/rfjakob/eme v1.1.2 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sethvargo/go-retry v0.2.4 // indirect
)

2
go.sum
View file

@ -209,6 +209,8 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rfjakob/eme v1.1.2 h1:SxziR8msSOElPayZNFfQw4Tjx/Sbaeeh3eRvrHVMUs4=
github.com/rfjakob/eme v1.1.2/go.mod h1:cVvpasglm/G3ngEfcfT/Wt0GwhkuO32pf/poW6Nyk1k=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=

560
internal/crypt/cipher.go Normal file
View file

@ -0,0 +1,560 @@
package crypt
import (
"bytes"
"context"
"crypto/aes"
gocipher "crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
"sync"
"golang.org/x/crypto/nacl/secretbox"
"golang.org/x/crypto/scrypt"
)
const (
nameCipherBlockSize = aes.BlockSize
fileMagic = "TELDRIVE\x00\x00"
fileMagicSize = len(fileMagic)
fileNonceSize = 24
fileHeaderSize = fileMagicSize + fileNonceSize
blockHeaderSize = secretbox.Overhead
blockDataSize = 64 * 1024
blockSize = blockHeaderSize + blockDataSize
)
var (
ErrorBadDecryptUTF8 = errors.New("bad decryption - utf-8 invalid")
ErrorBadDecryptControlChar = errors.New("bad decryption - contains control chars")
ErrorNotAMultipleOfBlocksize = errors.New("not a multiple of blocksize")
ErrorTooShortAfterDecode = errors.New("too short after base32 decode")
ErrorTooLongAfterDecode = errors.New("too long after base32 decode")
ErrorEncryptedFileTooShort = errors.New("file is too short to be encrypted")
ErrorEncryptedFileBadHeader = errors.New("file has truncated block header")
ErrorEncryptedBadMagic = errors.New("not an encrypted file - bad magic string")
ErrorEncryptedBadBlock = errors.New("failed to authenticate decrypted block - bad password?")
ErrorBadBase32Encoding = errors.New("bad base32 filename encoding")
ErrorFileClosed = errors.New("file already closed")
ErrorNotAnEncryptedFile = errors.New("not an encrypted file - does not match suffix")
ErrorBadSeek = errors.New("Seek beyond end of file")
ErrorSuffixMissingDot = errors.New("suffix config setting should include a '.'")
defaultSalt = []byte{0xA8, 0x0D, 0xF4, 0x3A, 0x8F, 0xBD, 0x03, 0x08, 0xA7, 0xCA, 0xB8, 0x3E, 0x58, 0x1F, 0x86, 0xB1}
)
var (
fileMagicBytes = []byte(fileMagic)
)
type ReadSeekCloser interface {
io.Reader
io.Seeker
io.Closer
}
type OpenRangeSeek func(ctx context.Context, offset, limit int64) (io.ReadCloser, error)
func readFill(r io.Reader, buf []byte) (n int, err error) {
var nn int
for n < len(buf) && err == nil {
nn, err = r.Read(buf[n:])
n += nn
}
return n, err
}
type Cipher struct {
dataKey [32]byte
nameKey [32]byte
nameTweak [nameCipherBlockSize]byte
block gocipher.Block
buffers sync.Pool
cryptoRand io.Reader
}
func NewCipher(password, salt string) (*Cipher, error) {
c := &Cipher{
cryptoRand: rand.Reader,
}
c.buffers.New = func() interface{} {
return new([blockSize]byte)
}
err := c.Key(password, salt)
if err != nil {
return nil, err
}
return c, nil
}
func (c *Cipher) Key(password, salt string) (err error) {
const keySize = len(c.dataKey) + len(c.nameKey) + len(c.nameTweak)
var saltBytes = defaultSalt
if salt != "" {
saltBytes = []byte(salt)
}
var key []byte
if password == "" {
key = make([]byte, keySize)
} else {
key, err = scrypt.Key([]byte(password), saltBytes, 16384, 8, 1, keySize)
if err != nil {
return err
}
}
copy(c.dataKey[:], key)
copy(c.nameKey[:], key[len(c.dataKey):])
copy(c.nameTweak[:], key[len(c.dataKey)+len(c.nameKey):])
c.block, err = aes.NewCipher(c.nameKey[:])
return err
}
func (c *Cipher) getBlock() *[blockSize]byte {
return c.buffers.Get().(*[blockSize]byte)
}
func (c *Cipher) putBlock(buf *[blockSize]byte) {
c.buffers.Put(buf)
}
type nonce [fileNonceSize]byte
func (n *nonce) pointer() *[fileNonceSize]byte {
return (*[fileNonceSize]byte)(n)
}
func (n *nonce) fromReader(in io.Reader) error {
read, err := readFill(in, (*n)[:])
if read != fileNonceSize {
return fmt.Errorf("short read of nonce: %w", err)
}
return nil
}
func (n *nonce) fromBuf(buf []byte) {
read := copy((*n)[:], buf)
if read != fileNonceSize {
panic("buffer to short to read nonce")
}
}
func (n *nonce) carry(i int) {
for ; i < len(*n); i++ {
digit := (*n)[i]
newDigit := digit + 1
(*n)[i] = newDigit
if newDigit >= digit {
// exit if no carry
break
}
}
}
func (n *nonce) increment() {
n.carry(0)
}
func (n *nonce) add(x uint64) {
carry := uint16(0)
for i := 0; i < 8; i++ {
digit := (*n)[i]
xDigit := byte(x)
x >>= 8
carry += uint16(digit) + uint16(xDigit)
(*n)[i] = byte(carry)
carry >>= 8
}
if carry != 0 {
n.carry(8)
}
}
type encrypter struct {
mu sync.Mutex
in io.Reader
c *Cipher
nonce nonce
buf *[blockSize]byte
readBuf *[blockSize]byte
bufIndex int
bufSize int
err error
}
func (c *Cipher) newEncrypter(in io.Reader, nonce *nonce) (*encrypter, error) {
fh := &encrypter{
in: in,
c: c,
buf: c.getBlock(),
readBuf: c.getBlock(),
bufSize: fileHeaderSize,
}
if nonce != nil {
fh.nonce = *nonce
} else {
err := fh.nonce.fromReader(c.cryptoRand)
if err != nil {
return nil, err
}
}
copy((*fh.buf)[:], fileMagicBytes)
copy((*fh.buf)[fileMagicSize:], fh.nonce[:])
return fh, nil
}
func (fh *encrypter) Read(p []byte) (n int, err error) {
fh.mu.Lock()
defer fh.mu.Unlock()
if fh.err != nil {
return 0, fh.err
}
if fh.bufIndex >= fh.bufSize {
readBuf := (*fh.readBuf)[:blockDataSize]
n, err = readFill(fh.in, readBuf)
if n == 0 {
return fh.finish(err)
}
secretbox.Seal((*fh.buf)[:0], readBuf[:n], fh.nonce.pointer(), &fh.c.dataKey)
fh.bufIndex = 0
fh.bufSize = blockHeaderSize + n
fh.nonce.increment()
}
n = copy(p, (*fh.buf)[fh.bufIndex:fh.bufSize])
fh.bufIndex += n
return n, nil
}
func (fh *encrypter) finish(err error) (int, error) {
if fh.err != nil {
return 0, fh.err
}
fh.err = err
fh.c.putBlock(fh.buf)
fh.buf = nil
fh.c.putBlock(fh.readBuf)
fh.readBuf = nil
return 0, err
}
func (fh *encrypter) Close() error {
return nil
}
func (c *Cipher) EncryptData(in io.Reader) (io.ReadCloser, error) {
return c.newEncrypter(in, nil)
}
type decrypter struct {
mu sync.Mutex
rc io.ReadCloser
nonce nonce
initialNonce nonce
c *Cipher
buf *[blockSize]byte
readBuf *[blockSize]byte
bufIndex int
bufSize int
err error
limit int64
open OpenRangeSeek
}
func (c *Cipher) newDecrypter(rc io.ReadCloser) (*decrypter, error) {
fh := &decrypter{
rc: rc,
c: c,
buf: c.getBlock(),
readBuf: c.getBlock(),
limit: -1,
}
readBuf := (*fh.readBuf)[:fileHeaderSize]
n, err := readFill(fh.rc, readBuf)
if n < fileHeaderSize && err == io.EOF {
return nil, fh.finishAndClose(ErrorEncryptedFileTooShort)
} else if err != io.EOF && err != nil {
return nil, fh.finishAndClose(err)
}
if !bytes.Equal(readBuf[:fileMagicSize], fileMagicBytes) {
return nil, fh.finishAndClose(ErrorEncryptedBadMagic)
}
fh.nonce.fromBuf(readBuf[fileMagicSize:])
fh.initialNonce = fh.nonce
return fh, nil
}
func (c *Cipher) newDecrypterSeek(ctx context.Context, open OpenRangeSeek, offset, limit int64) (fh *decrypter, err error) {
var rc io.ReadCloser
doRangeSeek := false
setLimit := false
if offset == 0 && limit < 0 {
rc, err = open(ctx, 0, -1)
} else if offset == 0 {
_, underlyingLimit, _, _ := calculateUnderlying(offset, limit)
rc, err = open(ctx, 0, int64(fileHeaderSize)+underlyingLimit)
setLimit = true
} else {
rc, err = open(ctx, 0, int64(fileHeaderSize))
doRangeSeek = true
}
if err != nil {
return nil, err
}
fh, err = c.newDecrypter(rc)
if err != nil {
return nil, err
}
fh.open = open
if doRangeSeek {
_, err = fh.RangeSeek(ctx, offset, io.SeekStart, limit)
if err != nil {
_ = fh.Close()
return nil, err
}
}
if setLimit {
fh.limit = limit
}
return fh, nil
}
func (fh *decrypter) fillBuffer() (err error) {
readBuf := fh.readBuf
n, err := readFill(fh.rc, (*readBuf)[:])
if n == 0 {
return err
}
if n <= blockHeaderSize {
if err != nil && err != io.EOF {
return err
}
return ErrorEncryptedFileBadHeader
}
_, ok := secretbox.Open((*fh.buf)[:0], (*readBuf)[:n], fh.nonce.pointer(), &fh.c.dataKey)
if !ok {
if err != nil && err != io.EOF {
return err
}
for i := range (*fh.buf)[:n] {
(*fh.buf)[i] = 0
}
}
fh.bufIndex = 0
fh.bufSize = n - blockHeaderSize
fh.nonce.increment()
return nil
}
func (fh *decrypter) Read(p []byte) (n int, err error) {
fh.mu.Lock()
defer fh.mu.Unlock()
if fh.err != nil {
return 0, fh.err
}
if fh.bufIndex >= fh.bufSize {
err = fh.fillBuffer()
if err != nil {
return 0, fh.finish(err)
}
}
toCopy := fh.bufSize - fh.bufIndex
if fh.limit >= 0 && fh.limit < int64(toCopy) {
toCopy = int(fh.limit)
}
n = copy(p, (*fh.buf)[fh.bufIndex:fh.bufIndex+toCopy])
fh.bufIndex += n
if fh.limit >= 0 {
fh.limit -= int64(n)
if fh.limit == 0 {
return n, fh.finish(io.EOF)
}
}
return n, nil
}
func calculateUnderlying(offset, limit int64) (underlyingOffset, underlyingLimit, discard, blocks int64) {
blocks, discard = offset/blockDataSize, offset%blockDataSize
underlyingOffset = int64(fileHeaderSize) + blocks*(blockHeaderSize+blockDataSize)
underlyingLimit = int64(-1)
if limit >= 0 {
bytesToRead := limit - (blockDataSize - discard)
blocksToRead := int64(1)
if bytesToRead > 0 {
extraBlocksToRead, endBytes := bytesToRead/blockDataSize, bytesToRead%blockDataSize
if endBytes != 0 {
extraBlocksToRead++
}
blocksToRead += extraBlocksToRead
}
underlyingLimit = blocksToRead * (blockHeaderSize + blockDataSize)
}
return
}
func (fh *decrypter) RangeSeek(ctx context.Context, offset int64, whence int, limit int64) (int64, error) {
fh.mu.Lock()
defer fh.mu.Unlock()
if fh.open == nil {
return 0, fh.finish(errors.New("can't seek - not initialised with newDecrypterSeek"))
}
if whence != io.SeekStart {
return 0, fh.finish(errors.New("can only seek from the start"))
}
if fh.err == io.EOF {
fh.unFinish()
} else if fh.err != nil {
return 0, fh.err
}
underlyingOffset, underlyingLimit, discard, blocks := calculateUnderlying(offset, limit)
fh.nonce = fh.initialNonce
fh.nonce.add(uint64(blocks))
rc, err := fh.open(ctx, underlyingOffset, underlyingLimit)
if err != nil {
return 0, fh.finish(fmt.Errorf("couldn't reopen file with offset and limit: %w", err))
}
fh.rc = rc
err = fh.fillBuffer()
if err != nil {
return 0, fh.finish(err)
}
if int(discard) > fh.bufSize {
return 0, fh.finish(ErrorBadSeek)
}
fh.bufIndex = int(discard)
fh.limit = limit
return offset, nil
}
func (fh *decrypter) Seek(offset int64, whence int) (int64, error) {
return fh.RangeSeek(context.TODO(), offset, whence, -1)
}
func (fh *decrypter) finish(err error) error {
if fh.err != nil {
return fh.err
}
fh.err = err
fh.c.putBlock(fh.buf)
fh.buf = nil
fh.c.putBlock(fh.readBuf)
fh.readBuf = nil
return err
}
func (fh *decrypter) unFinish() {
fh.err = nil
fh.buf = fh.c.getBlock()
fh.readBuf = fh.c.getBlock()
fh.bufIndex = 0
fh.bufSize = 0
}
func (fh *decrypter) Close() error {
fh.mu.Lock()
defer fh.mu.Unlock()
if fh.err == ErrorFileClosed {
return fh.err
}
if fh.err == nil {
_ = fh.finish(io.EOF)
}
fh.err = ErrorFileClosed
if fh.rc == nil {
return nil
}
return fh.rc.Close()
}
func (fh *decrypter) finishAndClose(err error) error {
_ = fh.finish(err)
_ = fh.Close()
return err
}
func (c *Cipher) DecryptData(rc io.ReadCloser) (io.ReadCloser, error) {
out, err := c.newDecrypter(rc)
if err != nil {
return nil, err
}
return out, nil
}
func (c *Cipher) DecryptDataSeek(ctx context.Context, open OpenRangeSeek, offset, limit int64) (ReadSeekCloser, error) {
out, err := c.newDecrypterSeek(ctx, open, offset, limit)
if err != nil {
return nil, err
}
return out, nil
}
func (c *Cipher) EncryptedSize(size int64) int64 {
blocks, residue := size/blockDataSize, size%blockDataSize
encryptedSize := int64(fileHeaderSize) + blocks*(blockHeaderSize+blockDataSize)
if residue != 0 {
encryptedSize += blockHeaderSize + residue
}
return encryptedSize
}
func (c *Cipher) DecryptedSize(size int64) (int64, error) {
size -= int64(fileHeaderSize)
if size < 0 {
return 0, ErrorEncryptedFileTooShort
}
blocks, residue := size/blockSize, size%blockSize
decryptedSize := blocks * blockDataSize
if residue != 0 {
residue -= blockHeaderSize
if residue <= 0 {
return 0, ErrorEncryptedFileBadHeader
}
}
decryptedSize += residue
return decryptedSize, nil
}

View file

@ -0,0 +1,52 @@
package pkcs7
import "errors"
var (
ErrorPaddingNotFound = errors.New("bad PKCS#7 padding - not padded")
ErrorPaddingNotAMultiple = errors.New("bad PKCS#7 padding - not a multiple of blocksize")
ErrorPaddingTooLong = errors.New("bad PKCS#7 padding - too long")
ErrorPaddingTooShort = errors.New("bad PKCS#7 padding - too short")
ErrorPaddingNotAllTheSame = errors.New("bad PKCS#7 padding - not all the same")
)
func Pad(n int, buf []byte) []byte {
if n <= 1 || n >= 256 {
panic("bad multiple")
}
length := len(buf)
padding := n - (length % n)
for i := 0; i < padding; i++ {
buf = append(buf, byte(padding))
}
if (len(buf) % n) != 0 {
panic("padding failed")
}
return buf
}
func Unpad(n int, buf []byte) ([]byte, error) {
if n <= 1 || n >= 256 {
panic("bad multiple")
}
length := len(buf)
if length == 0 {
return nil, ErrorPaddingNotFound
}
if (length % n) != 0 {
return nil, ErrorPaddingNotAMultiple
}
padding := int(buf[length-1])
if padding > n {
return nil, ErrorPaddingTooLong
}
if padding == 0 {
return nil, ErrorPaddingTooShort
}
for i := 0; i < padding; i++ {
if buf[length-1-i] != byte(padding) {
return nil, ErrorPaddingNotAllTheSame
}
}
return buf[:length-padding], nil
}

View file

@ -0,0 +1,101 @@
package reader
import (
"context"
"io"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
)
type decrpytedReader struct {
ctx context.Context
parts []types.Part
pos int
client *telegram.Client
reader io.ReadCloser
bytesread int64
contentLength int64
cipher *crypt.Cipher
}
func NewDecryptedReader(
ctx context.Context,
client *telegram.Client,
parts []types.Part,
cipher *crypt.Cipher,
contentLength int64) (io.ReadCloser, error) {
r := &decrpytedReader{
ctx: ctx,
parts: parts,
client: client,
contentLength: contentLength,
cipher: cipher,
}
res, err := r.nextPart()
if err != nil {
return nil, err
}
r.reader = res
return r, nil
}
func (r *decrpytedReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
if err == io.EOF || n == 0 {
r.pos++
if r.pos < len(r.parts) {
r.reader, err = r.nextPart()
if err != nil {
return 0, err
}
}
}
r.bytesread += int64(n)
if r.bytesread == r.contentLength {
return n, io.EOF
}
return n, nil
}
func (r *decrpytedReader) Close() (err error) {
if r.reader != nil {
err = r.reader.Close()
r.reader = nil
return err
}
return nil
}
func (r *decrpytedReader) nextPart() (io.ReadCloser, error) {
return r.cipher.DecryptDataSeek(r.ctx,
func(ctx context.Context,
underlyingOffset,
underlyingLimit int64) (io.ReadCloser, error) {
var end int64
if underlyingLimit >= 0 {
end = min(r.parts[r.pos].Size-1, underlyingOffset+underlyingLimit-1)
}
return NewTGReader(r.ctx, r.client, types.Part{
Start: underlyingOffset,
End: end,
Location: r.parts[r.pos].Location,
})
}, r.parts[r.pos].Start, r.parts[r.pos].End-r.parts[r.pos].Start+1)
}

73
internal/reader/reader.go Normal file
View file

@ -0,0 +1,73 @@
package reader
import (
"context"
"io"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gotd/td/telegram"
)
type linearReader struct {
ctx context.Context
parts []types.Part
pos int
client *telegram.Client
reader io.ReadCloser
bytesread int64
contentLength int64
}
func NewLinearReader(ctx context.Context,
client *telegram.Client,
parts []types.Part,
contentLength int64,
) (reader io.ReadCloser, err error) {
r := &linearReader{
ctx: ctx,
parts: parts,
client: client,
contentLength: contentLength,
}
reader, err = NewTGReader(r.ctx, r.client, r.parts[r.pos])
if err != nil {
return nil, err
}
r.reader = reader
return r, nil
}
func (r *linearReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
if err == io.EOF || n == 0 {
r.pos++
if r.pos < len(r.parts) {
r.reader, err = NewTGReader(r.ctx, r.client, r.parts[r.pos])
if err != nil {
return 0, err
}
}
}
r.bytesread += int64(n)
if r.bytesread == r.contentLength {
return n, io.EOF
}
return n, nil
}
func (r *linearReader) Close() (err error) {
if r.reader != nil {
err = r.reader.Close()
r.reader = nil
return err
}
return nil
}

View file

@ -10,43 +10,39 @@ import (
"github.com/gotd/td/tg"
)
type linearReader struct {
ctx context.Context
parts []types.Part
pos int
client *telegram.Client
next func() ([]byte, error)
buffer []byte
bytesread int64
chunkSize int64
i int64
contentLength int64
type tgReader struct {
ctx context.Context
client *telegram.Client
location *tg.InputDocumentFileLocation
start int64
end int64
next func() ([]byte, error)
buffer []byte
bytesread int64
chunkSize int64
i int64
}
func (*linearReader) Close() error {
return nil
}
func NewTGReader(
ctx context.Context,
client *telegram.Client,
part types.Part,
func NewLinearReader(ctx context.Context, client *telegram.Client, parts []types.Part, contentLength int64) (io.ReadCloser, error) {
) (io.ReadCloser, error) {
r := &linearReader{
ctx: ctx,
parts: parts,
client: client,
chunkSize: int64(1024 * 1024),
contentLength: contentLength,
r := &tgReader{
ctx: ctx,
location: part.Location,
client: client,
start: part.Start,
end: part.End,
chunkSize: int64(1024 * 1024),
}
r.next = r.partStream()
return r, nil
}
func (r *linearReader) Read(p []byte) (n int, err error) {
if r.bytesread == r.contentLength {
return 0, io.EOF
}
func (r *tgReader) Read(p []byte) (n int, err error) {
if r.i >= int64(len(r.buffer)) {
r.buffer, err = r.next()
@ -54,36 +50,35 @@ func (r *linearReader) Read(p []byte) (n int, err error) {
return 0, err
}
if len(r.buffer) == 0 {
r.pos++
if r.pos == len(r.parts) {
return 0, io.EOF
} else {
r.next = r.partStream()
r.buffer, err = r.next()
if err != nil {
return 0, err
}
r.next = r.partStream()
r.buffer, err = r.next()
if err != nil {
return 0, err
}
}
r.i = 0
}
n = copy(p, r.buffer[r.i:])
r.i += int64(n)
r.bytesread += int64(n)
if r.bytesread == r.end-r.start+1 {
return n, io.EOF
}
return n, nil
}
func (r *linearReader) chunk(offset int64, limit int64) ([]byte, error) {
func (*tgReader) Close() error {
return nil
}
func (r *tgReader) chunk(offset int64, limit int64) ([]byte, error) {
req := &tg.UploadGetFileRequest{
Offset: offset,
Limit: int(limit),
Location: r.parts[r.pos].Location,
Location: r.location,
}
res, err := r.client.API().UploadGetFile(r.ctx, req)
@ -100,51 +95,38 @@ func (r *linearReader) chunk(offset int64, limit int64) ([]byte, error) {
}
}
func (r *linearReader) partStream() func() ([]byte, error) {
func (r *tgReader) partStream() func() ([]byte, error) {
start := r.parts[r.pos].Start
end := r.parts[r.pos].End
start := r.start
end := r.end
offset := start - (start % r.chunkSize)
firstPartCut := start - offset
lastPartCut := (end % r.chunkSize) + 1
partCount := int((end - offset + r.chunkSize) / r.chunkSize)
currentPart := 1
readData := func() ([]byte, error) {
if currentPart > partCount {
return make([]byte, 0), nil
}
res, err := r.chunk(offset, r.chunkSize)
if err != nil {
return nil, err
}
if len(res) == 0 {
return res, nil
} else if partCount == 1 {
res = res[firstPartCut:lastPartCut]
} else if currentPart == 1 {
res = res[firstPartCut:]
} else if currentPart == partCount {
res = res[:lastPartCut]
}
currentPart++
offset += r.chunkSize
return res, nil
}
return readData
}

View file

@ -0,0 +1,4 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE teldrive.files ADD COLUMN "encrypted" BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd

View file

@ -35,6 +35,7 @@ func ToFileOutFull(file models.File) *schemas.FileOutFull {
FileOut: ToFileOut(file),
Parts: parts,
ChannelID: *file.ChannelID,
Encrypted: file.Encrypted,
}
}

View file

@ -15,6 +15,7 @@ type File struct {
Size *int64 `gorm:"type:bigint"`
Starred bool `gorm:"default:false"`
Depth *int `gorm:"type:integer"`
Encrypted bool `gorm:"default:false"`
UserID int64 `gorm:"type:bigint;not null"`
Status string `gorm:"type:text"`
ParentID string `gorm:"type:text;index"`

View file

@ -41,6 +41,7 @@ type CreateFile struct {
Path string `json:"path" binding:"required"`
Size int64 `json:"size"`
ParentID string `json:"parentId"`
Encrypted bool `json:"encrypted"`
}
type FileOut struct {
@ -59,6 +60,7 @@ type FileOutFull struct {
FileOut
Parts []Part `json:"parts,omitempty"`
ChannelID int64 `json:"channelId"`
Encrypted bool `json:"encrypted"`
}
type UpdateFile struct {

View file

@ -1,9 +1,10 @@
package schemas
type UploadQuery struct {
Filename string `form:"fileName" binding:"required`
PartNo int `form:"partNo" binding:"required`
ChannelID int64 `form:"channelId" binding:"required`
Filename string `form:"fileName" binding:"required"`
PartNo int `form:"partNo" binding:"required"`
ChannelID int64 `form:"channelId" binding:"required"`
Encrypted bool `form:"encrypted"`
}
type UploadPartOut struct {

View file

@ -10,6 +10,7 @@ import (
"strconv"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/database"
"github.com/divyam234/teldrive/pkg/models"
@ -137,7 +138,7 @@ func getTGMessages(ctx context.Context, client *telegram.Client, parts []schemas
return messages, nil
}
func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
func getParts(ctx context.Context, cipher *crypt.Cipher, client *telegram.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) {
parts := []types.Part{}
@ -160,7 +161,12 @@ func getParts(ctx context.Context, client *telegram.Client, file *schemas.FileOu
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location := document.AsInputDocumentFileLocation()
parts = append(parts, types.Part{Location: location, Start: 0, End: document.Size - 1})
end := document.Size - 1
if cipher != nil {
end, _ = cipher.DecryptedSize(document.Size)
end -= 1
}
parts = append(parts, types.Part{Location: location, Start: 0, End: end, Size: document.Size})
}
cache.GetCache().Set(key, &parts, 3600)
return parts, nil
@ -187,12 +193,14 @@ func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part {
Location: parts[firstChunk].Location,
Start: startInFirstChunk,
End: endInLastChunk,
Size: parts[firstChunk].Size,
})
} else {
validParts = append(validParts, types.Part{
Location: parts[firstChunk].Location,
Start: startInFirstChunk,
End: parts[firstChunk].End,
Size: parts[firstChunk].Size,
})
// Add valid parts from any chunks in between.
@ -201,6 +209,7 @@ func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part {
Location: parts[i].Location,
Start: 0,
End: parts[i].End,
Size: parts[i].Size,
})
}
@ -209,6 +218,7 @@ func rangedParts(parts []types.Part, startByte, endByte int64) []types.Part {
Location: parts[lastChunk].Location,
Start: 0,
End: endInLastChunk,
Size: parts[lastChunk].Size,
})
}

View file

@ -13,6 +13,7 @@ import (
cnf "github.com/divyam234/teldrive/config"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/http_range"
"github.com/divyam234/teldrive/internal/md5"
"github.com/divyam234/teldrive/internal/reader"
@ -96,6 +97,7 @@ func (fs *FileService) CreateFile(c *gin.Context) (*schemas.FileOut, *types.AppE
fileDB.Type = fileIn.Type
fileDB.UserID = userId
fileDB.Status = "active"
fileDB.Encrypted = fileIn.Encrypted
if err := fs.Db.Create(&fileDB).Error; err != nil {
pgErr := err.(*pgconn.PgError)
@ -550,7 +552,15 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
config := cnf.GetConfig()
var token, channelUser string
var (
token, channelUser string
cipher *crypt.Cipher
lr io.ReadCloser
)
if file.Encrypted {
cipher, _ = crypt.NewCipher(config.EncryptionKey, config.EncryptionSalt)
}
if config.LazyStreamBots {
tgc.Workers.Set(tokens, file.ChannelID)
@ -559,12 +569,16 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
channelUser = strings.Split(token, ":")[0]
if r.Method != "HEAD" {
tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
parts, err := getParts(c, client, file, channelUser)
parts, err := getParts(c, cipher, client, file, channelUser)
if err != nil {
return err
}
parts = rangedParts(parts, start, end)
lr, _ := reader.NewLinearReader(c, client, parts, contentLength)
if file.Encrypted {
lr, _ = reader.NewDecryptedReader(c, client, parts, cipher, contentLength)
} else {
lr, _ = reader.NewLinearReader(c, client, parts, contentLength)
}
io.CopyN(w, lr, contentLength)
return nil
})
@ -599,13 +613,20 @@ func (fs *FileService) GetFileStream(c *gin.Context) {
}
if r.Method != "HEAD" {
parts, err := getParts(c, client.Tg, file, channelUser)
parts, err := getParts(c, cipher, client.Tg, file, channelUser)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
parts = rangedParts(parts, start, end)
lr, _ := reader.NewLinearReader(c, client.Tg, parts, contentLength)
if file.Encrypted {
lr, _ = reader.NewDecryptedReader(c, client.Tg, parts, cipher, contentLength)
} else {
lr, _ = reader.NewLinearReader(c, client.Tg, parts, contentLength)
}
io.CopyN(w, lr, contentLength)
}
}

View file

@ -9,6 +9,7 @@ import (
"time"
cnf "github.com/divyam234/teldrive/config"
"github.com/divyam234/teldrive/internal/crypt"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/mapper"
"github.com/divyam234/teldrive/pkg/schemas"
@ -105,7 +106,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
uploadId := c.Param("id")
file := c.Request.Body
fileStream := c.Request.Body
fileSize := c.Request.ContentLength
@ -144,11 +145,19 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty
return err
}
config := cnf.GetConfig()
if uploadQuery.Encrypted {
cipher, _ := crypt.NewCipher(config.EncryptionKey, config.EncryptionSalt)
fileSize = cipher.EncryptedSize(fileSize)
fileStream, _ = cipher.EncryptData(fileStream)
}
api := client.API()
u := uploader.NewUploader(api).WithThreads(16).WithPartSize(512 * 1024)
upload, err := u.Upload(c, uploader.NewUpload(fileName, file, fileSize))
upload, err := u.Upload(c, uploader.NewUpload(fileName, fileStream, fileSize))
if err != nil {
return err

View file

@ -15,6 +15,7 @@ type Part struct {
Location *tg.InputDocumentFileLocation
Start int64
End int64
Size int64
}
type JWTClaims struct {