teldrive/pkg/services/upload.go

278 lines
7.3 KiB
Go

package services
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/crypt"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/pool"
"github.com/tgdrive/teldrive/internal/tgc"
"go.uber.org/zap"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/uploader"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/pkg/mapper"
"github.com/tgdrive/teldrive/pkg/models"
)
const saltLength = 32
func (a *apiService) UploadsDelete(ctx context.Context, params api.UploadsDeleteParams) error {
if err := a.db.Where("upload_id = ?", params.ID).Delete(&models.Upload{}).Error; err != nil {
return &api.ErrorStatusCode{StatusCode: 500, Response: api.Error{Message: err.Error(), Code: 500}}
}
return nil
}
func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPartsByIdParams) ([]api.UploadPart, error) {
parts := []models.Upload{}
if err := a.db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", params.ID).
Where("created_at < ?", time.Now().UTC().Add(a.cnf.TG.Uploads.Retention)).
Find(&parts).Error; err != nil {
return nil, &apiError{err: err}
}
return mapper.ToUploadOut(parts), nil
}
func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
userId, _ := auth.GetUser(ctx)
var stats []api.UploadStats
err := a.db.Raw(`
SELECT
dates.upload_date::date AS upload_date,
COALESCE(SUM(files.size), 0)::bigint AS total_uploaded
FROM
generate_series(CURRENT_DATE - INTERVAL '1 day' * @days, CURRENT_DATE, '1 day') AS dates(upload_date)
LEFT JOIN
teldrive.files AS files
ON
dates.upload_date = DATE_TRUNC('day', files.created_at)
WHERE
dates.upload_date >= CURRENT_DATE - INTERVAL '1 day' * @days and (files.type='file' or files.type is null) and (files.user_id=@userId or files.user_id is null)
GROUP BY
dates.upload_date
ORDER BY
dates.upload_date
`, sql.Named("days", params.Days-1), sql.Named("userId", userId)).Scan(&stats).Error
if err != nil {
return nil, &apiError{err: err}
}
return stats, nil
}
func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadReqWithContentType, params api.UploadsUploadParams) (*api.UploadPart, error) {
var (
channelId int64
err error
client *telegram.Client
token string
index int
channelUser string
out api.UploadPart
)
if params.Encrypted.Value && a.cnf.TG.Uploads.EncryptionKey == "" {
return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
}
userId, session := auth.GetUser(ctx)
fileStream := req.Content.Data
fileSize := params.ContentLength
if params.ChannelId.Value == 0 {
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, err
}
} else {
channelId = params.ChannelId.Value
}
tokens, err := getBotsToken(a.db, a.cache, userId, channelId)
if err != nil {
return nil, err
}
if len(tokens) == 0 {
client, err = tgc.AuthClient(ctx, &a.cnf.TG, session)
if err != nil {
return nil, err
}
channelUser = strconv.FormatInt(userId, 10)
} else {
a.worker.Set(tokens, channelId)
token, index = a.worker.Next(channelId)
client, err = tgc.BotClient(ctx, a.boltdb, &a.cnf.TG, token)
if err != nil {
return nil, err
}
channelUser = strings.Split(token, ":")[0]
}
middlewares := tgc.NewMiddleware(&a.cnf.TG, tgc.WithFloodWait(),
tgc.WithRecovery(ctx),
tgc.WithRetry(a.cnf.TG.Uploads.MaxRetries),
tgc.WithRateLimit())
uploadPool := pool.NewPool(client, int64(a.cnf.TG.PoolSize), middlewares...)
defer uploadPool.Close()
logger := logging.FromContext(ctx)
logger.Debug("uploading chunk",
zap.String("fileName", params.FileName),
zap.String("partName", params.PartName),
zap.String("bot", channelUser),
zap.Int("botNo", index),
zap.Int("chunkNo", params.PartNo),
zap.Int64("partSize", fileSize),
)
err = tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
if err != nil {
return err
}
var salt string
if params.Encrypted.Value {
//gen random Salt
salt, _ = generateRandomSalt()
cipher, err := crypt.NewCipher(a.cnf.TG.Uploads.EncryptionKey, salt)
if err != nil {
return err
}
fileSize = crypt.EncryptedSize(fileSize)
fileStream, err = cipher.EncryptData(fileStream)
if err != nil {
return err
}
}
client := uploadPool.Default(ctx)
u := uploader.NewUploader(client).WithThreads(a.cnf.TG.Uploads.Threads).WithPartSize(512 * 1024)
upload, err := u.Upload(ctx, uploader.NewUpload(params.PartName, fileStream, fileSize))
if err != nil {
return err
}
document := message.UploadedDocument(upload).Filename(params.PartName).ForceFile(true)
sender := message.NewSender(client)
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ChannelID,
AccessHash: channel.AccessHash})
res, err := target.Media(ctx, document)
if err != nil {
return err
}
updates := res.(*tg.Updates)
var message *tg.Message
for _, update := range updates.Updates {
channelMsg, ok := update.(*tg.UpdateNewChannelMessage)
if ok {
message = channelMsg.Message.(*tg.Message)
break
}
}
if message.ID == 0 {
return fmt.Errorf("upload failed")
}
partUpload := &models.Upload{
Name: params.PartName,
UploadId: params.ID,
PartId: message.ID,
ChannelID: channelId,
Size: fileSize,
PartNo: int(params.PartNo),
UserId: userId,
Encrypted: params.Encrypted.Value,
Salt: salt,
}
if err := a.db.Create(partUpload).Error; err != nil {
if message.ID != 0 {
client.ChannelsDeleteMessages(ctx, &tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: []int{message.ID}})
}
return err
}
msgs, _ := client.ChannelsGetMessages(ctx,
&tg.ChannelsGetMessagesRequest{Channel: channel, ID: []tg.InputMessageClass{&tg.InputMessageID{ID: message.ID}}})
if msgs != nil && len(msgs.(*tg.MessagesChannelMessages).Messages) == 0 {
return errors.New("upload failed")
}
out = api.UploadPart{
Name: partUpload.Name,
PartId: partUpload.PartId,
ChannelId: partUpload.ChannelID,
PartNo: partUpload.PartNo,
Size: partUpload.Size,
Encrypted: partUpload.Encrypted,
}
out.SetSalt(api.NewOptString(partUpload.Salt))
return nil
})
if err != nil {
logger.Debug("upload failed", zap.String("fileName", params.FileName),
zap.String("partName", params.PartName),
zap.Int("chunkNo", params.PartNo))
return nil, err
}
logger.Debug("upload finished", zap.String("fileName", params.FileName),
zap.String("partName", params.PartName),
zap.Int("chunkNo", params.PartNo))
return &out, nil
}
func generateRandomSalt() (string, error) {
randomBytes := make([]byte, saltLength)
_, err := rand.Read(randomBytes)
if err != nil {
return "", err
}
hasher := sha256.New()
hasher.Write(randomBytes)
hashedSalt := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
return hashedSalt, nil
}