fix: cache invalidation in file update

This commit is contained in:
Bhunter 2025-01-08 11:27:52 +01:00
parent 126a2c4af3
commit ffa9a27bb0

View file

@ -123,7 +123,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
return nil, &apiError{err: errors.New("file not found"), code: 404} return nil, &apiError{err: errors.New("file not found"), code: 404}
} }
file := mapper.ToFileOut(res[0], true) file := res[0]
newIds := []api.Part{} newIds := []api.Part{}
@ -138,7 +138,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
for _, part := range file.Parts { for _, part := range file.Parts {
ids = append(ids, int(part.ID)) ids = append(ids, int(part.ID))
} }
messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value) messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID)
if err != nil { if err != nil {
return err return err
@ -179,7 +179,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
} }
} }
newIds = append(newIds, api.Part{ID: msg.ID, Salt: file.Parts[i].Salt}) p := api.Part{ID: msg.ID}
if file.Parts[i].Salt.Value != "" {
p.Salt = file.Parts[i].Salt
}
newIds = append(newIds, p)
} }
return nil return nil
@ -189,6 +193,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
return nil, &apiError{err: err} return nil, &apiError{err: err}
} }
if len(newIds) != len(file.Parts) {
return nil, &apiError{err: errors.New("failed to copy all file parts")}
}
var parentId string var parentId string
if !isUUID(req.Destination) { if !isUUID(req.Destination) {
var destRes []models.File var destRes []models.File
@ -204,10 +212,12 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
dbFile := models.File{} dbFile := models.File{}
dbFile.Name = req.NewName.Or(file.Name) dbFile.Name = req.NewName.Or(file.Name)
dbFile.Size = utils.Ptr(file.Size.Value) dbFile.Size = file.Size
dbFile.Type = string(file.Type) dbFile.Type = string(file.Type)
dbFile.MimeType = file.MimeType.Or(defaultContentType) dbFile.MimeType = file.MimeType
dbFile.Parts = datatypes.NewJSONSlice(newIds) if len(newIds) > 0 {
dbFile.Parts = datatypes.NewJSONSlice(newIds)
}
dbFile.UserID = userId dbFile.UserID = userId
dbFile.Status = "active" dbFile.Status = "active"
dbFile.ParentID = sql.NullString{ dbFile.ParentID = sql.NullString{
@ -215,8 +225,15 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
Valid: true, Valid: true,
} }
dbFile.ChannelID = &channelId dbFile.ChannelID = &channelId
dbFile.Encrypted = file.Encrypted.Value dbFile.Encrypted = file.Encrypted
dbFile.Category = string(file.Category.Value) dbFile.Category = string(file.Category)
if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
dbFile.UpdatedAt = req.UpdatedAt.Value
dbFile.CreatedAt = req.UpdatedAt.Value
} else {
dbFile.UpdatedAt = time.Now().UTC()
dbFile.CreatedAt = time.Now().UTC()
}
if err := a.db.Create(&dbFile).Error; err != nil { if err := a.db.Create(&dbFile).Error; err != nil {
return nil, &apiError{err: err} return nil, &apiError{err: err}
@ -236,8 +253,12 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
channelId int64 channelId int64
) )
if fileIn.Path.Value == "" && fileIn.ParentId.Value == "" {
return nil, &apiError{err: errors.New("parent id or path is required"), code: 409}
}
if fileIn.Path.Value != "" { if fileIn.Path.Value != "" {
path = strings.ReplaceAll(path, "//", "/") path = strings.ReplaceAll(fileIn.Path.Value, "//", "/")
if path != "/" { if path != "/" {
path = strings.TrimSuffix(path, "/") path = strings.TrimSuffix(path, "/")
} }
@ -258,8 +279,6 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
Valid: true, Valid: true,
} }
} else {
return nil, &apiError{err: errors.New("parent id or path is required"), code: 409}
} }
if fileIn.Type == "folder" { if fileIn.Type == "folder" {
@ -295,6 +314,13 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
fileDB.UserID = userId fileDB.UserID = userId
fileDB.Status = "active" fileDB.Status = "active"
fileDB.Encrypted = fileIn.Encrypted.Value fileDB.Encrypted = fileIn.Encrypted.Value
if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() {
fileDB.UpdatedAt = fileIn.UpdatedAt.Value
fileDB.CreatedAt = fileIn.UpdatedAt.Value
} else {
fileDB.UpdatedAt = time.Now().UTC()
fileDB.CreatedAt = time.Now().UTC()
}
if err := a.db.Create(&fileDB).Error; err != nil { if err := a.db.Create(&fileDB).Error; err != nil {
if database.IsKeyConflictErr(err) { if database.IsKeyConflictErr(err) {
return nil, &apiError{err: errors.New("file already exists"), code: 409} return nil, &apiError{err: errors.New("file already exists"), code: 409}
@ -491,9 +517,9 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
if req.Size.Value != 0 { if req.Size.Value != 0 {
updateDb.Size = utils.Ptr(req.Size.Value) updateDb.Size = utils.Ptr(req.Size.Value)
} }
if req.UpdatedAt.IsSet() { if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
updateDb.UpdatedAt = req.UpdatedAt.Value updateDb.UpdatedAt = req.UpdatedAt.Value
} else { } else if !req.UpdatedAt.IsSet() && params.Skiputs.Value == "0" {
updateDb.UpdatedAt = time.Now().UTC() updateDb.UpdatedAt = time.Now().UTC()
} }
@ -578,9 +604,9 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
} }
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
tgc.DeleteMessages(ctx, client, *file.ChannelID, ids) tgc.DeleteMessages(ctx, client, *file.ChannelID, ids)
keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s:%d", params.ID, userId)} keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)}
for _, part := range file.Parts { for _, part := range file.Parts {
keys = append(keys, fmt.Sprintf("files:location:%d:%s:%d", userId, params.ID, part.ID)) keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID))
} }
a.cache.Delete(keys...) a.cache.Delete(keys...)