teldrive/pkg/services/file.go

557 lines
14 KiB
Go
Raw Normal View History

2023-08-07 03:32:46 +08:00
package services
import (
2023-09-27 17:36:12 +08:00
"context"
2023-08-07 03:32:46 +08:00
"encoding/base64"
"fmt"
"io"
2023-12-03 16:05:32 +08:00
"mime"
2023-08-07 03:32:46 +08:00
"net/http"
"strconv"
"strings"
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/database"
2023-12-03 16:05:32 +08:00
"github.com/divyam234/teldrive/internal/http_range"
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/internal/md5"
"github.com/divyam234/teldrive/internal/reader"
"github.com/divyam234/teldrive/internal/tgc"
"github.com/divyam234/teldrive/internal/utils"
"github.com/divyam234/teldrive/pkg/logging"
2023-12-03 03:47:23 +08:00
"github.com/divyam234/teldrive/pkg/mapper"
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
2023-11-09 19:10:37 +08:00
"github.com/gotd/td/tg"
"go.uber.org/zap"
2023-08-07 03:32:46 +08:00
2023-12-22 03:13:26 +08:00
"github.com/jackc/pgx/v5/pgtype"
2023-08-07 03:32:46 +08:00
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type FileService struct {
db *gorm.DB
2024-02-13 00:02:55 +08:00
cnf *config.TGConfig
worker *tgc.StreamWorker
}
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
2024-02-13 00:02:55 +08:00
return &FileService{db: db, cnf: &cnf.TG, worker: worker}
2023-12-03 05:46:53 +08:00
}
// CreateFile inserts a new folder or file record owned by userId.
//
// fileIn.Path is trimmed and, when non-empty, resolved to the parent folder's
// ID; an unresolvable path yields a 404. For folders the stored path is the
// parent path joined with fileIn.Name. For files the channel falls back to the
// user's default channel when fileIn.ChannelID is 0, and the Telegram parts
// (ID + salt) are copied from the request. A unique-key violation on insert is
// reported as 409 Conflict; any other insert error is returned as-is.
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
	var fileDB models.File

	fileIn.Path = strings.TrimSpace(fileIn.Path)
	if fileIn.Path != "" {
		pathId, err := fs.getPathId(fileIn.Path, userId)
		if err != nil {
			return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
		}
		if pathId == "" {
			// Never hand back an AppError with a nil Error: callers format
			// appErr.Error directly.
			return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
		}
		fileDB.ParentID = pathId
	}

	switch fileIn.Type {
	case "folder":
		fileDB.MimeType = "drive/folder"
		if fileIn.Path == "/" {
			fileDB.Path = "/" + fileIn.Name
		} else {
			fileDB.Path = fileIn.Path + "/" + fileIn.Name
		}
		// NOTE(review): this yields 1 for both parent "/" and parent "/a", so
		// children of root and of a first-level folder get the same depth.
		// Confirm against the depth convention used by teldrive.create_directories.
		fileDB.Depth = utils.IntPointer(len(strings.Split(fileIn.Path, "/")) - 1)

	case "file":
		fileDB.Path = ""

		channelId := fileIn.ChannelID
		if channelId == 0 {
			var err error
			channelId, err = GetDefaultChannel(c, fs.db, userId)
			if err != nil {
				return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
			}
		}
		fileDB.ChannelID = &channelId
		fileDB.MimeType = fileIn.MimeType

		parts := make(models.Parts, 0, len(fileIn.Parts))
		for _, part := range fileIn.Parts {
			parts = append(parts, models.Part{ID: part.ID, Salt: part.Salt})
		}
		fileDB.Parts = &parts
		fileDB.Starred = false
		fileDB.Size = &fileIn.Size
	}

	fileDB.Name = fileIn.Name
	fileDB.Type = fileIn.Type
	fileDB.UserID = userId
	fileDB.Status = "active"
	fileDB.Encrypted = fileIn.Encrypted

	if err := fs.db.Create(&fileDB).Error; err != nil {
		if database.IsKeyConflictErr(err) {
			return nil, &types.AppError{Error: database.ErrKeyConflict, Code: http.StatusConflict}
		}
		return nil, &types.AppError{Error: err}
	}

	return mapper.ToFileOut(fileDB), nil
}
// UpdateFile applies update to the file/folder with the given id owned by
// userId and returns the updated record.
//
// Renaming a folder goes through teldrive.update_folder (which receives the
// user ID). Every other update is a plain UPDATE ... RETURNING, scoped to both
// the id and the owning user so one user cannot modify another's files.
// A missing (or foreign) record yields a 404.
func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileUpdate) (*schemas.FileOut, *types.AppError) {
	var (
		files []models.File
		chain *gorm.DB
	)

	if update.Type == "folder" && update.Name != "" {
		chain = fs.db.Raw("select * from teldrive.update_folder(?, ?, ?)", id, update.Name, userId).Scan(&files)
	} else {
		chain = fs.db.Model(&files).Clauses(clause.Returning{}).
			Where("id = ?", id).
			Where("user_id = ?", userId).
			Updates(update)
	}

	if chain.Error != nil {
		return nil, &types.AppError{Error: chain.Error}
	}
	// Guard the slice as well as RowsAffected: files[0] must exist.
	if chain.RowsAffected == 0 || len(files) == 0 {
		return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
	}

	return mapper.ToFileOut(files[0]), nil
}
// GetFileByID loads a single file record by primary key. A missing record is
// reported as database.ErrNotFound with a 404 code; any other database error
// is returned unchanged with no code.
func (fs *FileService) GetFileByID(id string) (*schemas.FileOutFull, *types.AppError) {
	var record models.File
	lookupErr := fs.db.Where("id = ?", id).First(&record).Error

	switch {
	case lookupErr == nil:
		return mapper.ToFileOutFull(record), nil
	case database.IsRecordNotFoundErr(lookupErr):
		return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
	default:
		return nil, &types.AppError{Error: lookupErr}
	}
}
// ListFiles returns one page of files for userId according to fquery.Op:
//
//   - "list":   children of the folder at fquery.Path, folders first;
//   - "find":   records matching the non-zero Name/Type/ParentID/Starred/Path
//     fields (Path+Name is translated into ParentID+Name);
//   - "search": full-text match of fquery.Search against file names.
//
// Every mode is restricted to active files owned by userId. Pagination is
// keyset-based: when a full page is returned, NextPageToken carries the
// base64-encoded sort-column value of the last item.
func (fs *FileService) ListFiles(userId int64, fquery *schemas.FileQuery) (*schemas.FileResponse, *types.AppError) {
	var pathId string
	if fquery.Path != "" {
		var err error
		pathId, err = fs.getPathId(fquery.Path, userId)
		if err != nil {
			return nil, &types.AppError{Error: err, Code: http.StatusNotFound}
		}
	}

	query := setOrderFilter(fs.db.Limit(fquery.PerPage), fquery)

	// Ownership/status conditions shared by every mode. GORM turns the
	// non-zero fields of this struct into equality conditions.
	filter := &models.File{UserID: userId, Status: "active"}

	switch fquery.Op {
	case "list":
		query = query.Order("type DESC").Order(getOrder(fquery)).
			Where("parent_id = ?", pathId).
			Where(filter).
			Model(filter)

	case "find":
		filter.Name = fquery.Name
		filter.Type = fquery.Type
		filter.ParentID = fquery.ParentID
		if fquery.Starred != nil {
			filter.Starred = *fquery.Starred
		}
		filter.Path = fquery.Path

		if fquery.Path != "" && fquery.Name != "" {
			// Looking up a name inside a folder: match on the resolved parent
			// instead of the item's own path.
			filter.ParentID = pathId
			filter.Path = ""
		}

		query = query.Order("type DESC").Order(getOrder(fquery)).
			Where(filter).
			Model(filter)

	case "search":
		query = query.Where("teldrive.get_tsquery(?) @@ teldrive.get_tsvector(name)", fquery.Search).
			Where(filter).
			Order(getOrder(fquery)).
			Model(filter)
	}

	if fquery.Path == "" {
		query = query.Select("*,(select path from teldrive.files as f where f.id = files.parent_id) as parent_path")
	}

	// Non-nil so an empty page encodes as [] rather than null.
	files := []schemas.FileOut{}
	if err := query.Scan(&files).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}

	token := ""
	if len(files) > 0 && len(files) == fquery.PerPage {
		lastItem := files[len(files)-1]
		lastValue := utils.GetField(&lastItem, utils.CamelToPascalCase(fquery.Sort))
		token = base64.StdEncoding.EncodeToString([]byte(lastValue))
	}

	return &schemas.FileResponse{Files: files, NextPageToken: token}, nil
}
2023-12-25 06:06:14 +08:00
func (fs *FileService) getPathId(path string, userId int64) (string, error) {
2023-11-30 02:58:55 +08:00
var file models.File
2023-11-01 21:39:54 +08:00
if err := fs.db.Model(&models.File{}).Select("id").Where("path = ?", path).Where("user_id = ?", userId).
First(&file).Error; database.IsRecordNotFoundErr(err) {
return "", database.ErrNotFound
2023-11-01 21:39:54 +08:00
}
2023-11-30 02:58:55 +08:00
return file.ID, nil
}
// MakeDirectory creates payload.Path (and any missing ancestors) for userId
// via teldrive.create_directories and returns the first row the function
// yields.
func (fs *FileService) MakeDirectory(userId int64, payload *schemas.MkDir) (*schemas.FileOut, *types.AppError) {
	var files []models.File
	if err := fs.db.Raw("select * from teldrive.create_directories(?, ?)", userId, payload.Path).
		Scan(&files).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}
	// Indexing an empty result would panic; report it as an error instead.
	if len(files) == 0 {
		return nil, &types.AppError{Error: fmt.Errorf("creating directory %q returned no rows", payload.Path)}
	}
	return mapper.ToFileOut(files[0]), nil
}
// MoveFiles moves the files whose IDs are in payload.Files to
// payload.Destination for userId via the teldrive.move_items function.
func (fs *FileService) MoveFiles(userId int64, payload *schemas.FileOperation) (*schemas.Message, *types.AppError) {
	// The dimension length must equal the number of elements: pgx encodes
	// only as many elements as the dimensions describe, so a fixed length of
	// 1 silently dropped every ID after the first.
	items := pgtype.Array[string]{
		Elements: payload.Files,
		Valid:    true,
		Dims:     []pgtype.ArrayDimension{{Length: int32(len(payload.Files)), LowerBound: 1}},
	}
	if err := fs.db.Exec("select * from teldrive.move_items(? , ? , ?)", items, payload.Destination, userId).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}
	return &schemas.Message{Message: "files moved"}, nil
}
// DeleteFiles deletes the files whose IDs are in payload.Files by calling the
// teldrive.delete_files procedure.
//
// NOTE(review): userId is not passed to the procedure, so ownership of the
// IDs is not checked here — confirm delete_files (or the caller) enforces it.
func (fs *FileService) DeleteFiles(userId int64, payload *schemas.FileOperation) (*schemas.Message, *types.AppError) {
	result := fs.db.Exec("call teldrive.delete_files($1)", payload.Files)
	if result.Error != nil {
		return nil, &types.AppError{Error: result.Error}
	}
	return &schemas.Message{Message: "files deleted"}, nil
}
// MoveDirectory moves the directory at payload.Source to payload.Destination
// for userId via the teldrive.move_directory function.
func (fs *FileService) MoveDirectory(userId int64, payload *schemas.DirMove) (*schemas.Message, *types.AppError) {
	const stmt = "select * from teldrive.move_directory(? , ? , ?)"
	result := fs.db.Exec(stmt, payload.Source, payload.Destination, userId)
	if result.Error != nil {
		return nil, &types.AppError{Error: result.Error}
	}
	return &schemas.Message{Message: "directory moved"}, nil
}
2023-11-09 19:10:37 +08:00
// CopyFile duplicates a file (request body: schemas.Copy) into
// payload.Destination under payload.Name.
//
// Each Telegram message backing the source file is re-sent to the same
// channel using the authenticated user's session, and a new file record
// pointing at the re-sent messages is inserted. Part salts are carried over so
// encrypted copies stay decryptable. The destination folder is created if
// missing.
func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppError) {
	var payload schemas.Copy
	if err := c.ShouldBindJSON(&payload); err != nil {
		return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
	}

	userId, session := GetUserAuth(c)

	client, err := tgc.AuthClient(c, fs.cnf, session)
	if err != nil {
		return nil, &types.AppError{Error: err}
	}

	var res []models.File
	if err := fs.db.Model(&models.File{}).
		Where("id = ?", payload.ID).
		Where("user_id = ?", userId).
		Find(&res).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}
	if len(res) == 0 {
		return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
	}

	file := mapper.ToFileOutFull(res[0])

	newIds := make(models.Parts, 0, len(file.Parts))
	err = tgc.RunWithAuth(c, client, "", func(ctx context.Context) error {
		user := strconv.FormatInt(userId, 10)

		messages, err := getTGMessages(c, client, file.Parts, file.ChannelID, user)
		if err != nil {
			return err
		}

		channel, err := GetChannelById(c, client, file.ChannelID, user)
		if err != nil {
			return err
		}

		for i, message := range messages.Messages {
			// Deleted parts come back as non-*tg.Message values; fail the copy
			// instead of panicking on the assertions.
			item, ok := message.(*tg.Message)
			if !ok {
				return fmt.Errorf("part %d: unexpected message type %T", i, message)
			}
			media, ok := item.Media.(*tg.MessageMediaDocument)
			if !ok {
				return fmt.Errorf("part %d: unexpected media type %T", i, item.Media)
			}
			document, ok := media.Document.(*tg.Document)
			if !ok {
				return fmt.Errorf("part %d: unexpected document type %T", i, media.Document)
			}

			randomID, err := randInt64()
			if err != nil {
				return err
			}

			request := tg.MessagesSendMediaRequest{
				Silent:   true,
				Peer:     &tg.InputPeerChannel{ChannelID: channel.ChannelID, AccessHash: channel.AccessHash},
				Media:    &tg.InputMediaDocument{ID: document.AsInput()},
				RandomID: randomID,
			}
			sent, err := client.API().MessagesSendMedia(ctx, &request)
			if err != nil {
				return err
			}

			updates, ok := sent.(*tg.Updates)
			if !ok {
				return fmt.Errorf("part %d: unexpected updates type %T", i, sent)
			}

			var msg *tg.Message
			for _, update := range updates.Updates {
				channelMsg, isChannelMsg := update.(*tg.UpdateNewChannelMessage)
				if !isChannelMsg {
					continue
				}
				if m, isMsg := channelMsg.Message.(*tg.Message); isMsg {
					msg = m
					break
				}
			}
			if msg == nil {
				return fmt.Errorf("part %d: copied message not found in updates", i)
			}

			part := models.Part{ID: int64(msg.ID)}
			if i < len(file.Parts) {
				// Keep the per-part salt, otherwise encrypted copies cannot be
				// decrypted.
				part.Salt = file.Parts[i].Salt
			}
			newIds = append(newIds, part)
		}
		return nil
	})
	if err != nil {
		return nil, &types.AppError{Error: err}
	}

	var destRes []models.File
	if err := fs.db.Raw("select * from teldrive.create_directories(?, ?)", userId, payload.Destination).Scan(&destRes).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}
	if len(destRes) == 0 {
		return nil, &types.AppError{Error: fmt.Errorf("destination %q could not be resolved", payload.Destination)}
	}
	dest := destRes[0]

	dbFile := models.File{}
	dbFile.Name = payload.Name
	dbFile.Size = &file.Size
	dbFile.Type = file.Type
	dbFile.MimeType = file.MimeType
	dbFile.Parts = &newIds
	dbFile.UserID = userId
	dbFile.Starred = false
	dbFile.Status = "active"
	dbFile.ParentID = dest.ID
	dbFile.ChannelID = &file.ChannelID
	dbFile.Encrypted = file.Encrypted

	if err := fs.db.Create(&dbFile).Error; err != nil {
		return nil, &types.AppError{Error: err}
	}

	return mapper.ToFileOut(dbFile), nil
}
2023-08-14 04:58:06 +08:00
// GetFileStream streams a file's bytes over HTTP.
//
// The request must carry a session hash in the "hash" query parameter. A
// single byte range from the Range header is honoured (206 Partial Content);
// multiple ranges are rejected and unsatisfiable ranges get 416. With "d=1"
// the file is served as an attachment instead of inline. Data is read from
// Telegram either through the user's own session (when stream bots are
// disabled or none are configured) or through a bot chosen by the stream
// worker. Encrypted files are decrypted on the fly.
//
// All setup that can fail runs before the status line and content headers are
// committed, so failures produce a clean error response.
func (fs *FileService) GetFileStream(c *gin.Context) {
	w := c.Writer
	r := c.Request
	fileID := c.Param("fileID")

	authHash := c.Query("hash")
	if authHash == "" {
		http.Error(w, "missing hash param", http.StatusBadRequest)
		return
	}

	store := cache.FromContext(c)

	session, err := getSessionByHash(fs.db, store, authHash)
	if err != nil {
		http.Error(w, "invalid hash", http.StatusBadRequest)
		return
	}

	file := &schemas.FileOutFull{}
	key := fmt.Sprintf("files:%s", fileID)
	if err := store.Get(key, file); err != nil {
		var appErr *types.AppError
		file, appErr = fs.GetFileByID(fileID)
		if appErr != nil {
			status := http.StatusBadRequest
			if appErr.Code != 0 {
				status = appErr.Code
			}
			http.Error(w, appErr.Error.Error(), status)
			return
		}
		// Best-effort cache fill; a failure only costs a DB lookup next time.
		store.Set(key, file, 0)
	}

	c.Header("Accept-Ranges", "bytes")

	status := http.StatusOK
	start, end := int64(0), file.Size-1
	if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
		ranges, err := http_range.Parse(rangeHeader, file.Size)
		if err == http_range.ErrNoOverlap {
			w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size))
			http.Error(w, http_range.ErrNoOverlap.Error(), http.StatusRequestedRangeNotSatisfiable)
			return
		}
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		if len(ranges) > 1 {
			http.Error(w, "multiple ranges are not supported", http.StatusRequestedRangeNotSatisfiable)
			return
		}
		// An empty result is treated as "no range": serve the whole file.
		if len(ranges) == 1 {
			start, end = ranges[0].Start, ranges[0].End
			status = http.StatusPartialContent
			c.Header("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size))
		}
	}
	contentLength := end - start + 1

	logger := logging.FromContext(c)

	tokens, err := getBotsToken(c, fs.db, session.UserId, file.ChannelID)
	if err != nil {
		logger.Errorw("failed to get bots", "error", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	var (
		client      *tgc.Client
		channelUser string
	)
	if fs.cnf.DisableStreamBots || len(tokens) == 0 {
		tgClient, err := tgc.AuthClient(c, fs.cnf, session.Session)
		if err != nil {
			logger.Errorw("file stream", "error", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		client, err = fs.worker.UserWorker(tgClient, session.UserId)
		if err != nil {
			logger.Errorw("file stream", "error", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		channelUser = strconv.FormatInt(session.UserId, 10)
		logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "user", channelUser, "start", start,
			"end", end, "fileSize", file.Size)
	} else {
		limit := min(len(tokens), fs.cnf.BgBotsLimit)
		fs.worker.Set(tokens[:limit], file.ChannelID)

		var index int
		client, index, err = fs.worker.Next(file.ChannelID)
		if err != nil {
			logger.Errorw("file stream", "error", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// A bot token has the form "<botID>:<secret>"; the bot ID identifies
		// the channel user.
		channelUser, _, _ = strings.Cut(tokens[index], ":")
		logger.Debugw("requesting file", "name", file.Name, "bot", channelUser, "botNo", index, "start", start,
			"end", end, "fileSize", file.Size)
	}

	var lr io.ReadCloser
	if r.Method != http.MethodHead {
		parts, err := getParts(c, client.Tg, file, channelUser)
		if err != nil {
			logger.Errorw("file stream", "error", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		parts = rangedParts(parts, start, end)

		if file.Encrypted {
			lr, err = reader.NewDecryptedReader(c, client.Tg, parts, contentLength, fs.cnf.Uploads.EncryptionKey)
		} else {
			lr, err = reader.NewLinearReader(c, client.Tg, parts, contentLength)
		}
		if err != nil {
			logger.Errorw("file stream", "error", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		defer lr.Close()
	}

	mimeType := file.MimeType
	if mimeType == "" {
		mimeType = "application/octet-stream"
	}
	disposition := "inline"
	if c.Query("d") == "1" {
		disposition = "attachment"
	}

	c.Header("Content-Type", mimeType)
	c.Header("Content-Length", strconv.FormatInt(contentLength, 10))
	// The standard header name is "ETag"; "E-Tag" is ignored by clients and caches.
	c.Header("ETag", fmt.Sprintf("\"%s\"", md5.FromString(file.ID+strconv.FormatInt(file.Size, 10))))
	c.Header("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat))
	c.Header("Content-Disposition", mime.FormatMediaType(disposition, map[string]string{"filename": file.Name}))

	w.WriteHeader(status)

	if lr == nil {
		// HEAD request: headers only.
		return
	}

	if _, err := io.CopyN(w, lr, contentLength); err != nil {
		// Usually the client went away mid-stream; the status is already sent.
		logger.Debugw("file stream interrupted", "name", file.Name, "error", err)
	}
}
func setOrderFilter(query *gorm.DB, fquery *schemas.FileQuery) *gorm.DB {
if fquery.NextPageToken != "" {
sortColumn := utils.CamelToSnake(fquery.Sort)
2023-08-13 04:15:19 +08:00
tokenValue, err := base64.StdEncoding.DecodeString(fquery.NextPageToken)
2023-08-07 03:32:46 +08:00
if err == nil {
if fquery.Order == "asc" {
2023-08-26 06:16:20 +08:00
return query.Where(fmt.Sprintf("%s > ?", sortColumn), string(tokenValue))
2023-08-07 03:32:46 +08:00
} else {
2023-08-26 06:16:20 +08:00
return query.Where(fmt.Sprintf("%s < ?", sortColumn), string(tokenValue))
2023-08-07 03:32:46 +08:00
}
}
}
2023-08-26 06:16:20 +08:00
return query
2023-08-07 03:32:46 +08:00
}
func getOrder(fquery *schemas.FileQuery) clause.OrderByColumn {
sortColumn := utils.CamelToSnake(fquery.Sort)
2023-08-07 03:32:46 +08:00
2023-12-03 16:05:32 +08:00
return clause.OrderByColumn{Column: clause.Column{Name: sortColumn},
Desc: fquery.Order == "desc"}
2023-08-07 03:32:46 +08:00
}