From 758f0f82ea96eb71d5e66edf037ff262875b1511 Mon Sep 17 00:00:00 2001 From: divyam234 <47589864+divyam234@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:42:48 +0530 Subject: [PATCH] refactor: Handle file updates in transaction (#292) * feat: Add Stream Bots Offset flag to run command * feat: Update Stream Bots Offset flag in run command * update * update * update * invalidate cache on update * chore: Update cache key for file location deletion * chore: Update cache key for file location retrieval * chore: Refactor file service to optimize token processing * chore: Optimize token processing in file service * chore: Update file service to use PUT method for updating parts --- api/router.go | 2 +- cmd/run.go | 1 + internal/config/config.go | 1 + internal/reader/tg_multi_reader_test.go | 3 +- internal/tgc/helpers.go | 22 +++++-- pkg/controller/file.go | 10 ++- pkg/schemas/file.go | 7 ++ pkg/services/common.go | 2 +- pkg/services/file.go | 86 +++++++++++++++++-------- 9 files changed, 97 insertions(+), 37 deletions(-) diff --git a/api/router.go b/api/router.go index eabc234..d8fbd86 100644 --- a/api/router.go +++ b/api/router.go @@ -30,7 +30,7 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi files.GET(":fileID/stream/:fileName", c.GetFileStream) files.HEAD(":fileID/download/:fileName", c.GetFileDownload) files.GET(":fileID/download/:fileName", c.GetFileDownload) - files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts) + files.PUT(":fileID/parts", authmiddleware, c.UpdateParts) files.GET("/category/stats", authmiddleware, c.GetCategoryStats) files.POST("/move", authmiddleware, c.MoveFiles) files.POST("/directories", authmiddleware, c.MakeDirectory) diff --git a/cmd/run.go b/cmd/run.go index 6239505..42281b4 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -98,6 +98,7 @@ func NewRun() *cobra.Command { duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration") runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads") runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers") + runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset") duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout") duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots") duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots") diff --git a/internal/config/config.go b/internal/config/config.go index d008fe4..177467a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,6 +51,7 @@ type TGConfig struct { Retention time.Duration } Stream struct { + BotsOffset int MultiThreads int Buffers int ChunkTimeout time.Duration diff --git a/internal/reader/tg_multi_reader_test.go b/internal/reader/tg_multi_reader_test.go index 2ff7305..b0b8274 100644 --- a/internal/reader/tg_multi_reader_test.go +++ b/internal/reader/tg_multi_reader_test.go @@ -46,10 +46,11 @@ type TestSuite struct { func (suite *TestSuite) SetupTest() { suite.config = &config.TGConfig{Stream: struct { + BotsOffset int MultiThreads int Buffers int ChunkTimeout time.Duration - }{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}} + }{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}} } func (suite *TestSuite) TestFullRead() { diff --git a/internal/tgc/helpers.go b/internal/tgc/helpers.go index 09059aa..01e87bb 100644 --- a/internal/tgc/helpers.go +++ b/internal/tgc/helpers.go @@ -202,7 +202,7 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i cache := cache.FromContext(ctx) - key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId) + key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId) err = cache.Get(key, location) @@ -221,12 +221,22 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i if err != nil { return nil, err } + messages, _ := res.(*tg.MessagesChannelMessages) - item := messages.Messages[0].(*tg.Message) - media := item.Media.(*tg.MessageMediaDocument) - document := media.Document.(*tg.Document) - location = document.AsInputDocumentFileLocation() - cache.Set(key, location, 3600) + + if len(messages.Messages) == 0 { + return nil, errors.New("no messages found") + } + + switch item := messages.Messages[0].(type) { + case *tg.MessageEmpty: + return nil, errors.New("no messages found") + case *tg.Message: + media := item.Media.(*tg.MessageMediaDocument) + document := media.Document.(*tg.Document) + location = document.AsInputDocumentFileLocation() + cache.Set(key, location, 1800) + } } return location, nil } diff --git a/pkg/controller/file.go b/pkg/controller/file.go index 9110a87..2a333f3 100644 --- a/pkg/controller/file.go +++ b/pkg/controller/file.go @@ -147,9 +147,15 @@ func (fc *Controller) DeleteFiles(c *gin.Context) { c.JSON(http.StatusOK, res) } -func (fc *Controller) DeleteFileParts(c *gin.Context) { +func (fc *Controller) UpdateParts(c *gin.Context) { - res, err := fc.FileService.DeleteFileParts(c, c.Param("fileID")) + var payload schemas.PartUpdate + if err := c.ShouldBindJSON(&payload); err != nil { + httputil.NewError(c, http.StatusBadRequest, err) + return + } + + res, err := fc.FileService.UpdateParts(c, c.Param("fileID"), &payload) if err != nil { httputil.NewError(c, err.Code, err.Error) return diff --git a/pkg/schemas/file.go b/pkg/schemas/file.go index 22c14c7..feaae1e 100644 --- a/pkg/schemas/file.go +++ b/pkg/schemas/file.go @@ -65,6 +65,7 @@ type FileUpdate struct { Starred *bool `json:"starred,omitempty"` ParentID string `json:"parentId,omitempty"` UpdatedAt time.Time `json:"updatedAt,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` Parts []Part `json:"parts,omitempty"` Size *int64 `json:"size,omitempty"` } @@ -82,6 +83,12 @@ type DeleteOperation struct { Files []string `json:"files,omitempty"` Source string `json:"source,omitempty"` } +type PartUpdate struct { + Parts []Part `json:"parts"` + UploadId string `json:"uploadId"` + UpdatedAt time.Time `json:"updatedAt" binding:"required"` + Size int64 `json:"size"` +} type DirMove struct { Source string `json:"source" binding:"required"` diff --git a/pkg/services/common.go b/pkg/services/common.go index 05b21c6..d729c13 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -19,7 +19,7 @@ func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, cache := cache.FromContext(ctx) parts := []types.Part{} - key := fmt.Sprintf("messages:%s:%s", file.Id, userID) + key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID) err := cache.Get(key, &parts) diff --git a/pkg/services/file.go b/pkg/services/file.go index d8f2382..45cd7e2 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/binary" + "errors" "fmt" "io" "mime" @@ -38,6 +39,10 @@ import ( "gorm.io/gorm/clause" ) +var ( + ErrorStreamAbandoned = errors.New("stream abandoned") +) + type buffer struct { Buf []byte } @@ -48,6 +53,7 @@ func (b *buffer) long() (int64, error) { return 0, err } return int64(v), nil + } func (b *buffer) uint64() (uint64, error) { const size = 8 @@ -149,6 +155,7 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU ParentID: update.ParentID, UpdatedAt: update.UpdatedAt, Size: update.Size, + CreatedAt: update.CreatedAt, } if update.Starred != nil { @@ -160,8 +167,6 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU } chain = fs.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", id).Updates(updateDb) - cache.Delete(fmt.Sprintf("files:%s", id)) - if chain.Error != nil { return nil, &types.AppError{Error: chain.Error} } @@ -169,6 +174,15 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound} } + cache.Delete(fmt.Sprintf("files:%s", id)) + + if len(update.Parts) > 0 { + cache.Delete(fmt.Sprintf("files:messages:%s:%d", id, userId)) + for _, part := range files[0].Parts { + cache.Delete(fmt.Sprintf("files:location:%d:%s:%d", userId, id, part.ID)) + } + } + return mapper.ToFileOut(files[0]), nil } @@ -356,32 +370,53 @@ func (fs *FileService) DeleteFiles(userId int64, payload *schemas.DeleteOperatio return &schemas.Message{Message: "files deleted"}, nil } -func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Message, *types.AppError) { +func (fs *FileService) UpdateParts(c *gin.Context, id string, payload *schemas.PartUpdate) (*schemas.Message, *types.AppError) { + var file models.File - if err := fs.db.Where("id = ?", id).First(&file).Error; err != nil { - if database.IsRecordNotFoundErr(err) { - return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound} + + updatePayload := models.File{ + UpdatedAt: payload.UpdatedAt, + Size: utils.Int64Pointer(payload.Size), + } + + if len(payload.Parts) > 0 { + updatePayload.Parts = datatypes.NewJSONSlice(payload.Parts) + } + + err := fs.db.Transaction(func(tx *gorm.DB) error { + + if err := tx.Where("id = ?", id).First(&file).Error; err != nil { + return err } - return nil, &types.AppError{Error: err} - } - _, session := auth.GetUser(c) + if err := tx.Model(models.File{}).Where("id = ?", id).Updates(updatePayload).Error; err != nil { + return err + } - client, _ := tgc.AuthClient(c, &fs.cnf.TG, session) + if payload.UploadId != "" { + if err := tx.Where("upload_id = ?", payload.UploadId).Delete(&models.Upload{}).Error; err != nil { + return err + } + } - ids := []int{} - - for _, part := range file.Parts { - ids = append(ids, int(part.ID)) - } - - err := tgc.DeleteMessages(c, client, *file.ChannelID, ids) + return nil + }) if err != nil { return nil, &types.AppError{Error: err} } - return &schemas.Message{Message: "file parts deleted"}, nil + if len(file.Parts) > 0 && file.ChannelID != nil { + _, session := auth.GetUser(c) + ids := []int{} + for _, part := range file.Parts { + ids = append(ids, int(part.ID)) + } + client, _ := tgc.AuthClient(c, &fs.cnf.TG, session) + tgc.DeleteMessages(c, client, *file.ChannelID, ids) + } + + return &schemas.Message{Message: "file updated"}, nil } func (fs *FileService) MoveDirectory(userId int64, payload *schemas.DirMove) (*schemas.Message, *types.AppError) { @@ -667,7 +702,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 { client, err = fs.worker.UserWorker(session.Session, session.UserId) if err != nil { - logger.Error("file stream", zap.Error(err)) + logger.Error(ErrorStreamAbandoned, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -675,13 +710,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { multiThreads = 0 } else { - - limit := min(len(tokens), fs.cnf.TG.BgBotsLimit) - - fs.worker.Set(tokens[:limit], file.ChannelID) + offset := fs.cnf.TG.Stream.BotsOffset - 1 + limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset) + fs.worker.Set(tokens[offset:limit], file.ChannelID) client, _, err = fs.worker.Next(file.ChannelID) if err != nil { - logger.Error("file stream", zap.Error(err)) + logger.Error(ErrorStreamAbandoned, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -690,7 +724,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { if r.Method != "HEAD" { parts, err := getParts(c, client.Tg.API(), file, channelUser) if err != nil { - logger.Error("file stream", err) + logger.Error(ErrorStreamAbandoned, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -705,7 +739,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { } if err != nil { - logger.Error("file stream", err) + logger.Error(ErrorStreamAbandoned, err) http.Error(w, err.Error(), http.StatusInternalServerError) return }