mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-04 05:26:02 +08:00
278 lines
7.3 KiB
Go
278 lines
7.3 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/tgdrive/teldrive/internal/api"
|
|
"github.com/tgdrive/teldrive/internal/auth"
|
|
"github.com/tgdrive/teldrive/internal/crypt"
|
|
"github.com/tgdrive/teldrive/internal/logging"
|
|
"github.com/tgdrive/teldrive/internal/pool"
|
|
"github.com/tgdrive/teldrive/internal/tgc"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/gotd/td/telegram"
|
|
"github.com/gotd/td/telegram/message"
|
|
"github.com/gotd/td/telegram/uploader"
|
|
"github.com/gotd/td/tg"
|
|
"github.com/tgdrive/teldrive/pkg/mapper"
|
|
"github.com/tgdrive/teldrive/pkg/models"
|
|
)
|
|
|
|
// saltLength is the number of random bytes read by generateRandomSalt before
// they are hashed and encoded into the salt string.
const saltLength = 32
|
|
|
|
func (a *apiService) UploadsDelete(ctx context.Context, params api.UploadsDeleteParams) error {
|
|
if err := a.db.Where("upload_id = ?", params.ID).Delete(&models.Upload{}).Error; err != nil {
|
|
return &api.ErrorStatusCode{StatusCode: 500, Response: api.Error{Message: err.Error(), Code: 500}}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPartsByIdParams) ([]api.UploadPart, error) {
|
|
parts := []models.Upload{}
|
|
if err := a.db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", params.ID).
|
|
Where("created_at < ?", time.Now().UTC().Add(a.cnf.TG.Uploads.Retention)).
|
|
Find(&parts).Error; err != nil {
|
|
return nil, &apiError{err: err}
|
|
}
|
|
return mapper.ToUploadOut(parts), nil
|
|
}
|
|
|
|
func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
|
|
userId, _ := auth.GetUser(ctx)
|
|
var stats []api.UploadStats
|
|
err := a.db.Raw(`
|
|
SELECT
|
|
dates.upload_date::date AS upload_date,
|
|
COALESCE(SUM(files.size), 0)::bigint AS total_uploaded
|
|
FROM
|
|
generate_series(CURRENT_DATE - INTERVAL '1 day' * @days, CURRENT_DATE, '1 day') AS dates(upload_date)
|
|
LEFT JOIN
|
|
teldrive.files AS files
|
|
ON
|
|
dates.upload_date = DATE_TRUNC('day', files.created_at)
|
|
WHERE
|
|
dates.upload_date >= CURRENT_DATE - INTERVAL '1 day' * @days and (files.type='file' or files.type is null) and (files.user_id=@userId or files.user_id is null)
|
|
GROUP BY
|
|
dates.upload_date
|
|
ORDER BY
|
|
dates.upload_date
|
|
`, sql.Named("days", params.Days-1), sql.Named("userId", userId)).Scan(&stats).Error
|
|
|
|
if err != nil {
|
|
return nil, &apiError{err: err}
|
|
|
|
}
|
|
return stats, nil
|
|
}
|
|
|
|
func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadReqWithContentType, params api.UploadsUploadParams) (*api.UploadPart, error) {
|
|
var (
|
|
channelId int64
|
|
err error
|
|
client *telegram.Client
|
|
token string
|
|
index int
|
|
channelUser string
|
|
out api.UploadPart
|
|
)
|
|
|
|
if params.Encrypted.Value && a.cnf.TG.Uploads.EncryptionKey == "" {
|
|
return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
|
|
}
|
|
|
|
userId, session := auth.GetUser(ctx)
|
|
|
|
fileStream := req.Content.Data
|
|
|
|
fileSize := params.ContentLength
|
|
|
|
if params.ChannelId.Value == 0 {
|
|
channelId, err = getDefaultChannel(a.db, a.cache, userId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
channelId = params.ChannelId.Value
|
|
}
|
|
|
|
tokens, err := getBotsToken(a.db, a.cache, userId, channelId)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(tokens) == 0 {
|
|
client, err = tgc.AuthClient(ctx, &a.cnf.TG, session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
channelUser = strconv.FormatInt(userId, 10)
|
|
} else {
|
|
a.worker.Set(tokens, channelId)
|
|
token, index = a.worker.Next(channelId)
|
|
client, err = tgc.BotClient(ctx, a.boltdb, &a.cnf.TG, token)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
channelUser = strings.Split(token, ":")[0]
|
|
}
|
|
|
|
middlewares := tgc.NewMiddleware(&a.cnf.TG, tgc.WithFloodWait(),
|
|
tgc.WithRecovery(ctx),
|
|
tgc.WithRetry(a.cnf.TG.Uploads.MaxRetries),
|
|
tgc.WithRateLimit())
|
|
|
|
uploadPool := pool.NewPool(client, int64(a.cnf.TG.PoolSize), middlewares...)
|
|
|
|
defer uploadPool.Close()
|
|
|
|
logger := logging.FromContext(ctx)
|
|
|
|
logger.Debug("uploading chunk",
|
|
zap.String("fileName", params.FileName),
|
|
zap.String("partName", params.PartName),
|
|
zap.String("bot", channelUser),
|
|
zap.Int("botNo", index),
|
|
zap.Int("chunkNo", params.PartNo),
|
|
zap.Int64("partSize", fileSize),
|
|
)
|
|
|
|
err = tgc.RunWithAuth(ctx, client, token, func(ctx context.Context) error {
|
|
|
|
channel, err := tgc.GetChannelById(ctx, client.API(), channelId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var salt string
|
|
|
|
if params.Encrypted.Value {
|
|
//gen random Salt
|
|
salt, _ = generateRandomSalt()
|
|
cipher, err := crypt.NewCipher(a.cnf.TG.Uploads.EncryptionKey, salt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fileSize = crypt.EncryptedSize(fileSize)
|
|
fileStream, err = cipher.EncryptData(fileStream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
client := uploadPool.Default(ctx)
|
|
|
|
u := uploader.NewUploader(client).WithThreads(a.cnf.TG.Uploads.Threads).WithPartSize(512 * 1024)
|
|
|
|
upload, err := u.Upload(ctx, uploader.NewUpload(params.PartName, fileStream, fileSize))
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
document := message.UploadedDocument(upload).Filename(params.PartName).ForceFile(true)
|
|
|
|
sender := message.NewSender(client)
|
|
|
|
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ChannelID,
|
|
AccessHash: channel.AccessHash})
|
|
|
|
res, err := target.Media(ctx, document)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
updates := res.(*tg.Updates)
|
|
|
|
var message *tg.Message
|
|
|
|
for _, update := range updates.Updates {
|
|
channelMsg, ok := update.(*tg.UpdateNewChannelMessage)
|
|
if ok {
|
|
message = channelMsg.Message.(*tg.Message)
|
|
break
|
|
}
|
|
}
|
|
|
|
if message.ID == 0 {
|
|
return fmt.Errorf("upload failed")
|
|
}
|
|
|
|
partUpload := &models.Upload{
|
|
Name: params.PartName,
|
|
UploadId: params.ID,
|
|
PartId: message.ID,
|
|
ChannelID: channelId,
|
|
Size: fileSize,
|
|
PartNo: int(params.PartNo),
|
|
UserId: userId,
|
|
Encrypted: params.Encrypted.Value,
|
|
Salt: salt,
|
|
}
|
|
|
|
if err := a.db.Create(partUpload).Error; err != nil {
|
|
if message.ID != 0 {
|
|
client.ChannelsDeleteMessages(ctx, &tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: []int{message.ID}})
|
|
}
|
|
return err
|
|
}
|
|
|
|
msgs, _ := client.ChannelsGetMessages(ctx,
|
|
&tg.ChannelsGetMessagesRequest{Channel: channel, ID: []tg.InputMessageClass{&tg.InputMessageID{ID: message.ID}}})
|
|
|
|
if msgs != nil && len(msgs.(*tg.MessagesChannelMessages).Messages) == 0 {
|
|
return errors.New("upload failed")
|
|
}
|
|
out = api.UploadPart{
|
|
Name: partUpload.Name,
|
|
PartId: partUpload.PartId,
|
|
ChannelId: partUpload.ChannelID,
|
|
PartNo: partUpload.PartNo,
|
|
Size: partUpload.Size,
|
|
Encrypted: partUpload.Encrypted,
|
|
}
|
|
out.SetSalt(api.NewOptString(partUpload.Salt))
|
|
return nil
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Debug("upload failed", zap.String("fileName", params.FileName),
|
|
zap.String("partName", params.PartName),
|
|
zap.Int("chunkNo", params.PartNo))
|
|
return nil, err
|
|
}
|
|
logger.Debug("upload finished", zap.String("fileName", params.FileName),
|
|
zap.String("partName", params.PartName),
|
|
zap.Int("chunkNo", params.PartNo))
|
|
return &out, nil
|
|
|
|
}
|
|
|
|
func generateRandomSalt() (string, error) {
|
|
randomBytes := make([]byte, saltLength)
|
|
_, err := rand.Read(randomBytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
hasher := sha256.New()
|
|
hasher.Write(randomBytes)
|
|
hashedSalt := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
|
|
|
|
return hashedSalt, nil
|
|
}
|