mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-08 08:10:05 +08:00
231 lines
5.8 KiB
Go
231 lines
5.8 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/divyam234/teldrive/internal/crypt"
|
|
"github.com/divyam234/teldrive/internal/kv"
|
|
"github.com/divyam234/teldrive/internal/tgc"
|
|
"github.com/divyam234/teldrive/pkg/logging"
|
|
"github.com/divyam234/teldrive/pkg/mapper"
|
|
"github.com/divyam234/teldrive/pkg/schemas"
|
|
|
|
"github.com/divyam234/teldrive/pkg/types"
|
|
|
|
"github.com/divyam234/teldrive/internal/config"
|
|
"github.com/divyam234/teldrive/pkg/models"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gotd/td/telegram"
|
|
"github.com/gotd/td/telegram/message"
|
|
"github.com/gotd/td/telegram/uploader"
|
|
"github.com/gotd/td/tg"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
const saltLength = 32
|
|
|
|
type UploadService struct {
|
|
db *gorm.DB
|
|
worker *tgc.UploadWorker
|
|
cnf *config.TGConfig
|
|
kv kv.KV
|
|
}
|
|
|
|
func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV) *UploadService {
|
|
return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv}
|
|
}
|
|
|
|
func (us *UploadService) GetUploadFileById(c *gin.Context) (*schemas.UploadOut, *types.AppError) {
|
|
uploadId := c.Param("id")
|
|
parts := []schemas.UploadPartOut{}
|
|
if err := us.db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", uploadId).
|
|
Where("created_at < ?", time.Now().UTC().Add(us.cnf.Uploads.Retention)).
|
|
Find(&parts).Error; err != nil {
|
|
return nil, &types.AppError{Error: err}
|
|
}
|
|
|
|
return &schemas.UploadOut{Parts: parts}, nil
|
|
}
|
|
|
|
func (us *UploadService) DeleteUploadFile(c *gin.Context) (*schemas.Message, *types.AppError) {
|
|
uploadId := c.Param("id")
|
|
if err := us.db.Where("upload_id = ?", uploadId).Delete(&models.Upload{}).Error; err != nil {
|
|
return nil, &types.AppError{Error: err}
|
|
}
|
|
return &schemas.Message{Message: "upload deleted"}, nil
|
|
}
|
|
|
|
func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *types.AppError) {
|
|
var (
|
|
uploadQuery schemas.UploadQuery
|
|
channelId int64
|
|
err error
|
|
client *telegram.Client
|
|
token string
|
|
index int
|
|
channelUser string
|
|
out *schemas.UploadPartOut
|
|
)
|
|
|
|
if err := c.ShouldBindQuery(&uploadQuery); err != nil {
|
|
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
|
|
}
|
|
|
|
if uploadQuery.Encrypted && us.cnf.Uploads.EncryptionKey == "" {
|
|
return nil, &types.AppError{Error: errors.New("encryption key not found"),
|
|
Code: http.StatusBadRequest}
|
|
}
|
|
|
|
userId, session := GetUserAuth(c)
|
|
|
|
uploadId := c.Param("id")
|
|
|
|
fileStream := c.Request.Body
|
|
|
|
fileSize := c.Request.ContentLength
|
|
|
|
defer c.Request.Body.Close()
|
|
|
|
if uploadQuery.ChannelID == 0 {
|
|
channelId, err = GetDefaultChannel(c, us.db, userId)
|
|
if err != nil {
|
|
return nil, &types.AppError{Error: err}
|
|
}
|
|
} else {
|
|
channelId = uploadQuery.ChannelID
|
|
}
|
|
|
|
tokens, err := getBotsToken(c, us.db, userId, channelId)
|
|
|
|
if err != nil {
|
|
return nil, &types.AppError{Error: err}
|
|
}
|
|
|
|
if len(tokens) == 0 {
|
|
client, _ = tgc.AuthClient(c, us.cnf, session)
|
|
channelUser = strconv.FormatInt(userId, 10)
|
|
} else {
|
|
us.worker.Set(tokens, channelId)
|
|
token, index = us.worker.Next(channelId)
|
|
client, _ = tgc.BotClient(c, us.kv, us.cnf, token)
|
|
channelUser = strings.Split(token, ":")[0]
|
|
}
|
|
|
|
logger := logging.FromContext(c)
|
|
|
|
logger.Debugw("uploading file", "fileName", uploadQuery.FileName,
|
|
"partName", uploadQuery.PartName,
|
|
"bot", channelUser, "botNo", index,
|
|
"chunkNo", uploadQuery.PartNo, "partSize", fileSize)
|
|
|
|
err = tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
|
|
|
|
channel, err := GetChannelById(ctx, client, channelId, channelUser)
|
|
|
|
if err != nil {
|
|
logger.Error("error", err)
|
|
return err
|
|
}
|
|
|
|
var salt string
|
|
|
|
if uploadQuery.Encrypted {
|
|
//gen random Salt
|
|
salt, _ = generateRandomSalt()
|
|
cipher, _ := crypt.NewCipher(us.cnf.Uploads.EncryptionKey, salt)
|
|
fileSize = crypt.EncryptedSize(fileSize)
|
|
fileStream, _ = cipher.EncryptData(fileStream)
|
|
}
|
|
|
|
api := client.API()
|
|
|
|
u := uploader.NewUploader(api).WithThreads(us.cnf.Uploads.Threads).WithPartSize(512 * 1024)
|
|
|
|
upload, err := u.Upload(c, uploader.NewUpload(uploadQuery.PartName, fileStream, fileSize))
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
document := message.UploadedDocument(upload).Filename(uploadQuery.PartName).ForceFile(true)
|
|
|
|
sender := message.NewSender(client.API())
|
|
|
|
target := sender.To(&tg.InputPeerChannel{ChannelID: channel.ChannelID,
|
|
AccessHash: channel.AccessHash})
|
|
|
|
res, err := target.Media(c, document)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
updates := res.(*tg.Updates)
|
|
|
|
var message *tg.Message
|
|
|
|
for _, update := range updates.Updates {
|
|
channelMsg, ok := update.(*tg.UpdateNewChannelMessage)
|
|
if ok {
|
|
message = channelMsg.Message.(*tg.Message)
|
|
break
|
|
}
|
|
}
|
|
|
|
partUpload := &models.Upload{
|
|
Name: uploadQuery.PartName,
|
|
UploadId: uploadId,
|
|
PartId: message.ID,
|
|
ChannelID: channelId,
|
|
Size: fileSize,
|
|
PartNo: uploadQuery.PartNo,
|
|
UserId: userId,
|
|
Encrypted: uploadQuery.Encrypted,
|
|
Salt: salt,
|
|
}
|
|
|
|
if err := us.db.Create(partUpload).Error; err != nil {
|
|
//delete uploaded part if upload fails
|
|
if message.ID != 0 {
|
|
api.ChannelsDeleteMessages(ctx, &tg.ChannelsDeleteMessagesRequest{Channel: channel, ID: []int{message.ID}})
|
|
}
|
|
return err
|
|
}
|
|
|
|
out = mapper.ToUploadOut(partUpload)
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, &types.AppError{Error: err}
|
|
}
|
|
|
|
logger.Debugw("upload finished", "fileName", uploadQuery.FileName,
|
|
"partName", uploadQuery.PartName,
|
|
"chunkNo", uploadQuery.PartNo)
|
|
|
|
return out, nil
|
|
}
|
|
|
|
func generateRandomSalt() (string, error) {
|
|
randomBytes := make([]byte, saltLength)
|
|
_, err := rand.Read(randomBytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
hasher := sha256.New()
|
|
hasher.Write(randomBytes)
|
|
hashedSalt := base64.URLEncoding.EncodeToString(hasher.Sum(nil))
|
|
|
|
return hashedSalt, nil
|
|
}
|