teldrive/services/file.service.go
2023-09-21 18:28:32 +05:30

473 lines
13 KiB
Go

package services
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/divyam234/teldrive/database"
"github.com/divyam234/teldrive/mapper"
"github.com/divyam234/teldrive/models"
"github.com/divyam234/teldrive/schemas"
"github.com/divyam234/teldrive/utils"
"github.com/divyam234/teldrive/utils/kv"
"github.com/divyam234/teldrive/utils/md5"
"github.com/divyam234/teldrive/utils/reader"
"github.com/divyam234/teldrive/utils/tgc"
"github.com/gotd/td/telegram"
"github.com/divyam234/teldrive/types"
"github.com/gin-gonic/gin"
"github.com/jackc/pgx/v5/pgconn"
"github.com/mitchellh/mapstructure"
range_parser "github.com/quantumsheep/range-parser"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type FileService struct {
Db *gorm.DB
}
func (fs *FileService) CreateFile(c *gin.Context) (*schemas.FileOut, *types.AppError) {
userId, _ := getUserAuth(c)
var fileIn schemas.FileIn
if err := c.ShouldBindJSON(&fileIn); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
fileIn.Path = strings.TrimSpace(fileIn.Path)
if fileIn.Path != "" {
var parent models.File
if err := fs.Db.Where("type = ? AND path = ?", "folder", fileIn.Path).First(&parent).Error; err != nil {
return nil, &types.AppError{Error: errors.New("parent directory not found"), Code: http.StatusNotFound}
}
fileIn.ParentID = parent.ID
}
if fileIn.Type == "folder" {
fileIn.MimeType = "drive/folder"
var fullPath string
if fileIn.Path == "/" {
fullPath = "/" + fileIn.Name
} else {
fullPath = fileIn.Path + "/" + fileIn.Name
}
fileIn.Path = fullPath
fileIn.Depth = utils.IntPointer(len(strings.Split(fileIn.Path, "/")) - 1)
} else if fileIn.Type == "file" {
fileIn.Path = ""
channelId, err := GetDefaultChannel(c, userId)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}
}
fileIn.ChannelID = &channelId
}
fileIn.UserID = userId
fileIn.Starred = utils.BoolPointer(false)
fileIn.Status = "active"
fileDb := mapper.MapFileInToFile(fileIn)
if err := fs.Db.Create(&fileDb).Error; err != nil {
pgErr := err.(*pgconn.PgError)
if pgErr.Code == "23505" {
return nil, &types.AppError{Error: errors.New("file exists"), Code: http.StatusBadRequest}
}
return nil, &types.AppError{Error: errors.New("failed to create a file"), Code: http.StatusBadRequest}
}
res := mapper.MapFileToFileOut(fileDb)
return &res, nil
}
func (fs *FileService) UpdateFile(c *gin.Context) (*schemas.FileOut, *types.AppError) {
fileID := c.Param("fileID")
var fileUpdate schemas.FileIn
var files []models.File
if err := c.ShouldBindJSON(&fileUpdate); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
if fileUpdate.Type == "folder" && fileUpdate.Name != "" {
if err := fs.Db.Raw("select * from teldrive.update_folder(?, ?)", fileID, fileUpdate.Name).Scan(&files).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to update the file"), Code: http.StatusInternalServerError}
}
} else {
fileDb := mapper.MapFileInToFile(fileUpdate)
if err := fs.Db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", fileID).Updates(fileDb).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to update the file"), Code: http.StatusInternalServerError}
}
}
if len(files) == 0 {
return nil, &types.AppError{Error: errors.New("file not updated"), Code: http.StatusNotFound}
}
file := mapper.MapFileToFileOut(files[0])
key := kv.Key("files", fileID)
database.KV.Delete(key)
return &file, nil
}
func (fs *FileService) GetFileByID(c *gin.Context) (*schemas.FileOutFull, error) {
fileID := c.Param("fileID")
var file []models.File
fs.Db.Model(&models.File{}).Where("id = ?", fileID).Find(&file)
if len(file) == 0 {
return nil, errors.New("file not found")
}
return mapper.MapFileToFileOutFull(file[0]), nil
}
func (fs *FileService) ListFiles(c *gin.Context) (*schemas.FileResponse, *types.AppError) {
userId, _ := getUserAuth(c)
var pagingParams schemas.PaginationQuery
pagingParams.PerPage = 200
if err := c.ShouldBindQuery(&pagingParams); err != nil {
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
var sortingParams schemas.SortingQuery
sortingParams.Order = "asc"
sortingParams.Sort = "name"
if err := c.ShouldBindQuery(&sortingParams); err != nil {
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
var fileQuery schemas.FileQuery
fileQuery.Op = "list"
fileQuery.Status = "active"
fileQuery.UserID = userId
if err := c.ShouldBindQuery(&fileQuery); err != nil {
return nil, &types.AppError{Error: errors.New("invalid params"), Code: http.StatusBadRequest}
}
query := fs.Db.Model(&models.File{}).Limit(pagingParams.PerPage).
Where(map[string]interface{}{"user_id": userId, "status": "active"})
if fileQuery.Op == "list" {
if pathExists, message := fs.CheckIfPathExists(&fileQuery.Path); !pathExists {
return nil, &types.AppError{Error: errors.New(message), Code: http.StatusNotFound}
}
setOrderFilter(query, &pagingParams, &sortingParams)
query.Order("type DESC").Order(getOrder(sortingParams)).
Where("parent_id in (?)", fs.Db.Model(&models.File{}).Select("id").Where("path = ?", fileQuery.Path))
} else if fileQuery.Op == "find" {
filterQuery := map[string]interface{}{}
err := mapstructure.Decode(fileQuery, &filterQuery)
if err != nil {
return nil, &types.AppError{Error: err, Code: http.StatusBadRequest}
}
delete(filterQuery, "op")
if filterQuery["updated_at"] == nil {
delete(filterQuery, "updated_at")
}
if filterQuery["path"] != nil && filterQuery["name"] != nil {
query.Where("parent_id in (?)", fs.Db.Model(&models.File{}).Select("id").Where("path = ?", filterQuery["path"]))
delete(filterQuery, "path")
}
query.Order("type DESC").Order(getOrder(sortingParams)).Where(filterQuery)
} else if fileQuery.Op == "search" {
query.Where("teldrive.get_tsquery(?) @@ teldrive.get_tsvector(name)", fileQuery.Search)
setOrderFilter(query, &pagingParams, &sortingParams)
query.Order(getOrder(sortingParams))
}
var results []schemas.FileOut
query.Find(&results)
token := ""
if len(results) == pagingParams.PerPage {
lastItem := results[len(results)-1]
token = utils.GetField(&lastItem, utils.CamelToPascalCase(sortingParams.Sort))
token = base64.StdEncoding.EncodeToString([]byte(token))
}
res := &schemas.FileResponse{Results: results, NextPageToken: token}
return res, nil
}
func (fs *FileService) CheckIfPathExists(path *string) (bool, string) {
query := fs.Db.Model(&models.File{}).Select("id").Where("path = ?", path)
var results []schemas.FileOut
query.Find(&results)
if len(results) == 0 {
return false, "This directory doesn't exist."
}
return true, ""
}
func (fs *FileService) MakeDirectory(c *gin.Context) (*schemas.FileOut, *types.AppError) {
var payload schemas.MkDir
var files []models.File
if err := c.ShouldBindJSON(&payload); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
userId, _ := getUserAuth(c)
if err := fs.Db.Raw("select * from teldrive.create_directories(?, ?)", userId, payload.Path).Scan(&files).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to create directories"), Code: http.StatusInternalServerError}
}
file := mapper.MapFileToFileOut(files[0])
return &file, nil
}
func (fs *FileService) MoveFiles(c *gin.Context) (*schemas.Message, *types.AppError) {
var payload schemas.FileOperation
if err := c.ShouldBindJSON(&payload); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
var destination models.File
if pathExists, message := fs.CheckIfPathExists(&payload.Destination); !pathExists {
return nil, &types.AppError{Error: errors.New(message), Code: http.StatusBadRequest}
}
if err := fs.Db.Model(&models.File{}).Select("id").Where("path = ?", payload.Destination).First(&destination).Error; errors.Is(err, gorm.ErrRecordNotFound) {
return nil, &types.AppError{Error: errors.New("destination not found"), Code: http.StatusNotFound}
}
if err := fs.Db.Model(&models.File{}).Where("id IN ?", payload.Files).UpdateColumn("parent_id", destination.ID).Error; err != nil {
return nil, &types.AppError{Error: errors.New("move failed"), Code: http.StatusInternalServerError}
}
return &schemas.Message{Status: true, Message: "files moved"}, nil
}
func (fs *FileService) DeleteFiles(c *gin.Context) (*schemas.Message, *types.AppError) {
var payload schemas.FileOperation
if err := c.ShouldBindJSON(&payload); err != nil {
return nil, &types.AppError{Error: errors.New("invalid request payload"), Code: http.StatusBadRequest}
}
if err := fs.Db.Exec("call teldrive.delete_files($1)", payload.Files).Error; err != nil {
return nil, &types.AppError{Error: errors.New("failed to delete files"), Code: http.StatusInternalServerError}
}
return &schemas.Message{Status: true, Message: "files deleted"}, nil
}
func (fs *FileService) GetFileStream(c *gin.Context) {
w := c.Writer
r := c.Request
fileID := c.Param("fileID")
authHash := c.Query("hash")
if authHash == "" {
http.Error(w, "misssing hash", http.StatusBadRequest)
return
}
data, err := database.KV.Get(kv.Key("sessions", authHash))
if err != nil {
http.Error(w, "hash missing relogin", http.StatusBadRequest)
return
}
jwtUser := &types.JWTClaims{}
err = json.Unmarshal(data, jwtUser)
if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest)
return
}
file := &schemas.FileOutFull{}
key := kv.Key("files", fileID)
err = kv.GetValue(database.KV, key, file)
if err != nil {
file, err = fs.GetFileByID(c)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
kv.SetValue(database.KV, key, file)
}
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
if ifModifiedSinceHeader != "" {
ifModifiedSinceTime, err := time.Parse(http.TimeFormat, ifModifiedSinceHeader)
if err == nil && file.UpdatedAt.Before(ifModifiedSinceTime.Add(1*time.Second)) {
w.WriteHeader(http.StatusNotModified)
return
}
}
w.Header().Set("Accept-Ranges", "bytes")
var start, end int64
rangeHeader := r.Header.Get("Range")
if rangeHeader == "" {
start = 0
end = file.Size - 1
w.WriteHeader(http.StatusOK)
} else {
ranges, err := range_parser.Parse(file.Size, r.Header.Get("Range"))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
start = ranges[0].Start
end = ranges[0].End
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size))
w.WriteHeader(http.StatusPartialContent)
}
contentLength := end - start + 1
mimeType := file.MimeType
if mimeType == "" {
mimeType = "application/octet-stream"
}
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
w.Header().Set("E-Tag", md5.FromString(file.ID+strconv.FormatInt(file.Size, 10)))
w.Header().Set("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat))
disposition := "inline"
if c.Query("d") == "1" {
disposition = "attachment"
}
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, file.Name))
userID, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
tokens, err := GetBotsToken(userID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var client *telegram.Client
var token string
var channelUser string
if len(tokens) == 0 {
client, _ = tgc.UserLogin(jwtUser.TgSession)
channelUser = jwtUser.Subject
} else {
tgc.Workers.Set(tokens)
token = tgc.Workers.Next()
client, _ = tgc.BotLogin(token)
channelUser = strings.Split(token, ":")[0]
}
if r.Method != "HEAD" {
tgc.RunWithAuth(c, client, token, func(ctx context.Context) error {
parts, err := getParts(c, client, file, channelUser)
if err != nil {
return err
}
parts = rangedParts(parts, start, end)
r, _ := reader.NewLinearReader(c, client, parts)
io.CopyN(w, r, contentLength)
return nil
})
}
}
func setOrderFilter(query *gorm.DB, pagingParams *schemas.PaginationQuery, sortingParams *schemas.SortingQuery) *gorm.DB {
if pagingParams.NextPageToken != "" {
sortColumn := sortingParams.Sort
if sortColumn == "name" {
sortColumn = "name collate numeric"
} else {
sortColumn = utils.CamelToSnake(sortingParams.Sort)
}
tokenValue, err := base64.StdEncoding.DecodeString(pagingParams.NextPageToken)
if err == nil {
if sortingParams.Order == "asc" {
return query.Where(fmt.Sprintf("%s > ?", sortColumn), string(tokenValue))
} else {
return query.Where(fmt.Sprintf("%s < ?", sortColumn), string(tokenValue))
}
}
}
return query
}
func getOrder(sortingParams schemas.SortingQuery) string {
sortColumn := utils.CamelToSnake(sortingParams.Sort)
if sortingParams.Sort == "name" {
sortColumn = "name collate numeric"
}
return fmt.Sprintf("%s %s", sortColumn, strings.ToUpper(sortingParams.Order))
}