teldrive/pkg/services/upload.go

267 lines
6.7 KiB
Go
Raw Normal View History

2023-08-13 04:15:19 +08:00
package services
import (
2023-08-24 02:40:40 +08:00
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
2023-12-18 11:29:03 +08:00
"errors"
2023-08-13 04:15:19 +08:00
"net/http"
2023-09-20 03:20:44 +08:00
"strconv"
"strings"
2023-11-01 01:46:52 +08:00
"time"
2023-08-13 04:15:19 +08:00
"github.com/divyam234/teldrive/config"
2023-12-08 05:46:06 +08:00
"github.com/divyam234/teldrive/internal/crypt"
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/pkg/mapper"
"github.com/divyam234/teldrive/pkg/schemas"
"go.uber.org/zap"
2023-08-13 04:15:19 +08:00
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/pkg/types"
2023-08-13 04:15:19 +08:00
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/pkg/models"
2023-08-13 04:15:19 +08:00
"github.com/gin-gonic/gin"
2023-09-20 03:20:44 +08:00
"github.com/gotd/td/telegram"
2023-08-13 04:15:19 +08:00
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/uploader"
"github.com/gotd/td/tg"
"gorm.io/gorm"
)
// saltLength is the number of random bytes generateRandomSalt reads before hashing.
const saltLength = 32
2023-08-13 04:15:19 +08:00
type UploadService struct {
Db *gorm.DB
log *zap.Logger
worker *tgc.UploadWorker
2023-08-13 04:15:19 +08:00
}
func NewUploadService(db *gorm.DB, logger *zap.Logger) *UploadService {
return &UploadService{Db: db, log: logger.Named("uploads"),
worker: &tgc.UploadWorker{}}
2023-12-03 05:46:53 +08:00
}
// generateRandomSalt reads saltLength bytes from crypto/rand, hashes them with
// SHA-256 and returns the digest encoded with the URL-safe base64 alphabet.
// An error is returned only when the random source cannot be read.
func generateRandomSalt() (string, error) {
	seed := make([]byte, saltLength)
	if _, err := rand.Read(seed); err != nil {
		return "", err
	}
	digest := sha256.Sum256(seed)
	return base64.URLEncoding.EncodeToString(digest[:]), nil
}
// logAndReturn logs err at error level, using context as the log message, and
// returns it wrapped in a *types.AppError that carries errCode.
func (us *UploadService) logAndReturn(context string, err error, errCode int) *types.AppError {
	appErr := &types.AppError{
		Error: err,
		Code:  errCode,
	}
	us.log.Error(context, zap.Error(err))
	return appErr
}
2023-08-13 04:15:19 +08:00
func (us *UploadService) GetUploadFileById(c *gin.Context) (*schemas.UploadOut, *types.AppError) {
uploadId := c.Param("id")
parts := []schemas.UploadPartOut{}
2023-11-01 01:46:52 +08:00
if err := us.Db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", uploadId).
Where("created_at >= ?", time.Now().UTC().AddDate(0, 0, -config.GetConfig().UploadRetention)).
2023-11-01 01:46:52 +08:00
Find(&parts).Error; err != nil {
return nil, us.logAndReturn("get upload", err, http.StatusInternalServerError)
2023-08-13 04:15:19 +08:00
}
return &schemas.UploadOut{Parts: parts}, nil
}
2023-12-03 03:47:23 +08:00
func (us *UploadService) DeleteUploadFile(c *gin.Context) (*schemas.Message, *types.AppError) {
2023-08-13 04:15:19 +08:00
uploadId := c.Param("id")
if err := us.Db.Where("upload_id = ?", uploadId).Delete(&models.Upload{}).Error; err != nil {
return nil, us.logAndReturn("delete upload", err, http.StatusInternalServerError)
2023-08-13 04:15:19 +08:00
}
2023-12-03 03:47:23 +08:00
return &schemas.Message{Message: "upload deleted"}, nil
2023-08-13 04:15:19 +08:00
}
2023-11-07 17:06:14 +08:00
func (us *UploadService) CreateUploadPart(c *gin.Context) (*schemas.UploadPartOut, *types.AppError) {
userId, _ := getUserAuth(c)
var payload schemas.UploadPart
if err := c.ShouldBindJSON(&payload); err != nil {
2023-12-04 03:21:30 +08:00
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
2023-11-07 17:06:14 +08:00
}
partUpload := &models.Upload{
2023-12-03 14:52:25 +08:00
Name: payload.Name,
UploadId: payload.UploadId,
PartId: payload.PartId,
ChannelID: payload.ChannelID,
Size: payload.Size,
PartNo: payload.PartNo,
UserId: userId,
2023-11-07 17:06:14 +08:00
}
if err := us.Db.Create(partUpload).Error; err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
2023-12-03 03:47:23 +08:00
out := mapper.ToUploadOut(partUpload)
2023-11-07 17:06:14 +08:00
return out, nil
}
2023-08-13 04:15:19 +08:00
func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *types.AppError) {
2023-11-16 23:21:35 +08:00
var (
uploadQuery schemas.UploadQuery
channelId int64
err error
client *telegram.Client
token string
index int
2023-11-16 23:21:35 +08:00
channelUser string
out *schemas.UploadPartOut
)
2023-08-13 04:15:19 +08:00
uploadQuery.PartNo = 1
if err := c.ShouldBindQuery(&uploadQuery); err != nil {
return nil, us.logAndReturn("UploadFile", err, http.StatusBadRequest)
2023-08-13 04:15:19 +08:00
}
2023-12-18 11:29:03 +08:00
var encryptedKey string
if uploadQuery.Encrypted {
encryptedKey = config.GetConfig().EncryptionKey
if encryptedKey == "" {
return nil, us.logAndReturn("UploadFile", errors.New("encryption key not set"), http.StatusInternalServerError)
}
}
2023-11-01 02:03:16 +08:00
userId, session := getUserAuth(c)
2023-08-13 04:15:19 +08:00
uploadId := c.Param("id")
2023-12-08 05:46:06 +08:00
fileStream := c.Request.Body
2023-08-13 04:15:19 +08:00
2023-08-24 02:40:40 +08:00
fileSize := c.Request.ContentLength
2023-12-18 19:12:49 +08:00
defer c.Request.Body.Close()
2023-11-16 23:21:35 +08:00
if uploadQuery.ChannelID == 0 {
channelId, err = GetDefaultChannel(c, userId)
if err != nil {
return nil, us.logAndReturn("uploadFile", err, http.StatusInternalServerError)
2023-11-16 23:21:35 +08:00
}
} else {
channelId = uploadQuery.ChannelID
}
2023-12-03 03:47:23 +08:00
tokens, err := getBotsToken(c, userId, channelId)
2023-09-20 03:20:44 +08:00
if err != nil {
return nil, us.logAndReturn("uploadFile", err, http.StatusInternalServerError)
2023-09-20 03:20:44 +08:00
}
2023-08-13 04:15:19 +08:00
2023-09-20 03:20:44 +08:00
if len(tokens) == 0 {
2023-11-25 12:31:29 +08:00
client, _ = tgc.UserLogin(c, session)
2023-09-20 03:20:44 +08:00
channelUser = strconv.FormatInt(userId, 10)
} else {
us.worker.Set(tokens, channelId)
token, index = us.worker.Next(channelId)
2023-11-25 12:31:29 +08:00
client, _ = tgc.BotLogin(c, token)
2023-09-20 03:20:44 +08:00
channelUser = strings.Split(token, ":")[0]
}
us.log.Debug("uploading file", zap.String("fileName", uploadQuery.FileName),
zap.String("partName", uploadQuery.PartName),
zap.String("bot", channelUser), zap.Int("botNo", index),
zap.Int("chunkNo", uploadQuery.PartNo), zap.Int64("partSize", fileSize))
err = tgc.RunWithAuth(c, us.log, client, token, func(ctx context.Context) error {
2023-09-20 03:20:44 +08:00
channel, err := GetChannelById(ctx, client, channelId, channelUser)
if err != nil {
return err
}
var salt string
2023-12-08 05:46:06 +08:00
if uploadQuery.Encrypted {
//gen random Salt
salt, _ = generateRandomSalt()
2023-12-18 11:29:03 +08:00
cipher, _ := crypt.NewCipher(encryptedKey, salt)
fileSize = crypt.EncryptedSize(fileSize)
2023-12-08 05:46:06 +08:00
fileStream, _ = cipher.EncryptData(fileStream)
}
2023-09-20 03:20:44 +08:00
api := client.API()
2023-08-13 04:15:19 +08:00
2023-11-25 12:31:29 +08:00
u := uploader.NewUploader(api).WithThreads(16).WithPartSize(512 * 1024)
2023-09-20 03:20:44 +08:00
upload, err := u.Upload(c, uploader.NewUpload(uploadQuery.PartName, fileStream, fileSize))
2023-08-13 04:15:19 +08:00
2023-08-24 02:40:40 +08:00
if err != nil {
return err
}
2023-08-13 04:15:19 +08:00
document := message.UploadedDocument(upload).Filename(uploadQuery.PartName).ForceFile(true)
2023-08-13 04:15:19 +08:00
2023-08-24 02:40:40 +08:00
sender := message.NewSender(client.API())
2023-08-13 04:15:19 +08:00
2023-09-20 03:20:44 +08:00
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ChannelID,
AccessHash: channel.AccessHash})
2023-08-24 02:40:40 +08:00
2023-09-20 03:20:44 +08:00
res, err := target.Media(c, document)
2023-08-13 04:15:19 +08:00
2023-08-24 02:40:40 +08:00
if err != nil {
return err
}
2023-08-13 04:15:19 +08:00
2023-08-24 02:40:40 +08:00
updates := res.(*tg.Updates)
2023-08-13 04:15:19 +08:00
2023-11-06 19:49:49 +08:00
var message *tg.Message
2023-08-24 02:40:40 +08:00
2023-11-06 19:49:49 +08:00
for _, update := range updates.Updates {
channelMsg, ok := update.(*tg.UpdateNewChannelMessage)
if ok {
message = channelMsg.Message.(*tg.Message)
break
}
2023-09-20 03:20:44 +08:00
}
partUpload := &models.Upload{
Name: uploadQuery.PartName,
2023-12-03 14:52:25 +08:00
UploadId: uploadId,
PartId: message.ID,
ChannelID: channelId,
Size: fileSize,
PartNo: uploadQuery.PartNo,
UserId: userId,
2023-12-08 06:05:40 +08:00
Encrypted: uploadQuery.Encrypted,
Salt: salt,
2023-09-20 03:20:44 +08:00
}
if err := us.Db.Create(partUpload).Error; err != nil {
2023-12-04 03:21:30 +08:00
return err
2023-09-20 03:20:44 +08:00
}
2023-12-03 03:47:23 +08:00
out = mapper.ToUploadOut(partUpload)
2023-09-20 03:20:44 +08:00
2023-08-24 02:40:40 +08:00
return nil
})
2023-08-13 04:15:19 +08:00
if err != nil {
return nil, us.logAndReturn("uploadFile", err, http.StatusInternalServerError)
2023-08-13 04:15:19 +08:00
}
us.log.Debug("upload finished", zap.String("fileName", uploadQuery.FileName),
zap.String("partName", uploadQuery.PartName),
zap.Int("chunkNo", uploadQuery.PartNo))
2023-08-13 04:15:19 +08:00
return out, nil
}