mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-03 21:14:28 +08:00
refactor: add UUID validation and improve file handling in copy and move operations
This commit is contained in:
parent
77e463b998
commit
74bae5fdc2
3 changed files with 62 additions and 46 deletions
5
internal/database/migrations/20250105180250_index.sql
Normal file
5
internal/database/migrations/20250105180250_index.sql
Normal file
|
@ -0,0 +1,5 @@
|
|||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
DROP INDEX IF EXISTS teldrive.idx_files_unique_file;
|
||||
|
||||
-- +goose StatementEnd
|
|
@ -13,8 +13,10 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/auth"
|
||||
"github.com/tgdrive/teldrive/internal/category"
|
||||
|
@ -68,6 +70,10 @@ func randInt64() (int64, error) {
|
|||
b := &buffer{Buf: buf[:]}
|
||||
return b.long()
|
||||
}
|
||||
func isUUID(str string) bool {
|
||||
_, err := uuid.Parse(str)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
type fullFileDB struct {
|
||||
models.File
|
||||
|
@ -182,15 +188,18 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
||||
var destRes []models.File
|
||||
|
||||
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
|
||||
Scan(&destRes).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
var parentId string
|
||||
if !isUUID(req.Destination) {
|
||||
var destRes []models.File
|
||||
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
|
||||
Scan(&destRes).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
parentId = destRes[0].Id
|
||||
} else {
|
||||
parentId = req.Destination
|
||||
}
|
||||
|
||||
dest := destRes[0]
|
||||
|
||||
dbFile := models.File{}
|
||||
|
||||
dbFile.Name = req.NewName.Or(file.Name)
|
||||
|
@ -201,7 +210,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
dbFile.UserID = userId
|
||||
dbFile.Status = "active"
|
||||
dbFile.ParentID = sql.NullString{
|
||||
String: dest.Id,
|
||||
String: parentId,
|
||||
Valid: true,
|
||||
}
|
||||
dbFile.ChannelID = &channelId
|
||||
|
@ -227,7 +236,6 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
|||
)
|
||||
|
||||
if fileIn.Path.Value != "" {
|
||||
path = strings.TrimSpace(fileIn.Path.Value)
|
||||
path = strings.ReplaceAll(path, "//", "/")
|
||||
if path != "/" {
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
@ -323,18 +331,10 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
|
|||
|
||||
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
if req.Source.Value == "" && len(req.Ids) == 0 {
|
||||
return &apiError{err: errors.New("source or ids is required"), code: 409}
|
||||
}
|
||||
if req.Source.Value != "" && len(req.Ids) == 0 {
|
||||
if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
} else if req.Source.Value == "" && len(req.Ids) > 0 {
|
||||
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -417,20 +417,23 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
|
|||
|
||||
func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
if req.Source.Value == "" && len(req.Ids) == 0 {
|
||||
return &apiError{err: errors.New("source or ids is required"), code: 409}
|
||||
items := pgtype.Array[string]{
|
||||
Elements: req.Ids,
|
||||
Valid: true,
|
||||
Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}},
|
||||
}
|
||||
if req.Source.Value != "" && len(req.Ids) > 0 {
|
||||
if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil {
|
||||
if !isUUID(req.Destination) {
|
||||
r, err := a.getFileFromPath(req.Destination, userId)
|
||||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
req.Destination = r.Id
|
||||
}
|
||||
if req.Source.Value == "" && len(req.Ids) == 0 {
|
||||
if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value,
|
||||
req.Destination, userId).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId).
|
||||
Update("parent_id", req.Destination).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
@ -469,10 +472,6 @@ func (a *apiService) FilesStream(ctx context.Context, params api.FilesStreamPara
|
|||
|
||||
func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, params api.FilesUpdateParams) (*api.File, error) {
|
||||
|
||||
var (
|
||||
files []models.File
|
||||
chain *gorm.DB
|
||||
)
|
||||
updateDb := models.File{}
|
||||
if req.Name.Value != "" {
|
||||
updateDb.Name = req.Name.Value
|
||||
|
@ -495,18 +494,17 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
|
|||
updateDb.UpdatedAt = req.UpdatedAt.Value
|
||||
}
|
||||
|
||||
chain = a.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", params.ID).Updates(updateDb)
|
||||
|
||||
if chain.Error != nil {
|
||||
return nil, &apiError{err: chain.Error}
|
||||
}
|
||||
if chain.RowsAffected == 0 {
|
||||
return nil, &apiError{err: errors.New("file not found"), code: 404}
|
||||
if err := a.db.Model(&models.File{}).Where("id = ?", params.ID).Updates(updateDb).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
||||
a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
|
||||
|
||||
return mapper.ToFileOut(files[0], false), nil
|
||||
file := models.File{}
|
||||
if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
return mapper.ToFileOut(file, false), nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error {
|
||||
|
@ -515,10 +513,8 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
|
|||
var file models.File
|
||||
|
||||
updatePayload := models.File{
|
||||
UpdatedAt: req.UpdatedAt,
|
||||
Size: utils.Ptr(req.Size),
|
||||
Size: utils.Ptr(req.Size),
|
||||
}
|
||||
|
||||
if req.ChannelId.Value == 0 {
|
||||
channelId, err := getDefaultChannel(a.db, a.cache, userId)
|
||||
if err != nil {
|
||||
|
@ -539,11 +535,21 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
|
|||
}
|
||||
updatePayload.Parts = datatypes.NewJSONSlice(parts)
|
||||
}
|
||||
if req.Name.Value != "" {
|
||||
updatePayload.Name = req.Name.Value
|
||||
}
|
||||
if req.ParentId.Value != "" {
|
||||
updatePayload.ParentID = sql.NullString{
|
||||
String: req.ParentId.Value,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
err := a.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil {
|
||||
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).
|
||||
Update("updated_at", req.UpdatedAt).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if req.UploadId.Value != "" {
|
||||
|
|
|
@ -97,7 +97,12 @@ func (afb *fileQueryBuilder) applyFileSpecificFilters(query *gorm.DB, filesQuery
|
|||
}
|
||||
|
||||
if filesQuery.ParentId.Value != "" {
|
||||
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
|
||||
if filesQuery.ParentId.Value == "nil" {
|
||||
query = query.Where("parent_id is NULL")
|
||||
} else {
|
||||
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if filesQuery.ParentId.Value == "" && filesQuery.Path.Value != "" && filesQuery.Query.Value == "" {
|
||||
|
|
Loading…
Add table
Reference in a new issue