refactor: Handle file updates in transaction (#292)

* feat: Add Stream Bots Offset flag to run command

* feat: Update Stream Bots Offset flag in run command

* update

* update

* update

* invalidate cache on update

* chore: Update cache key for file location deletion

* chore: Update cache key for file location retrieval

* chore: Refactor file service to optimize token processing

* chore: Optimize token processing in file service

* chore: Update file service to use PUT method for updating parts
This commit is contained in:
divyam234 2024-07-08 16:42:48 +05:30 committed by GitHub
parent 678de5bc63
commit 758f0f82ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 97 additions and 37 deletions

View file

@ -30,7 +30,7 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi
files.GET(":fileID/stream/:fileName", c.GetFileStream)
files.HEAD(":fileID/download/:fileName", c.GetFileDownload)
files.GET(":fileID/download/:fileName", c.GetFileDownload)
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
files.PUT(":fileID/parts", authmiddleware, c.UpdateParts)
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
files.POST("/move", authmiddleware, c.MoveFiles)
files.POST("/directories", authmiddleware, c.MakeDirectory)

View file

@ -98,6 +98,7 @@ func NewRun() *cobra.Command {
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset")
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")

View file

@ -51,6 +51,7 @@ type TGConfig struct {
Retention time.Duration
}
Stream struct {
BotsOffset int
MultiThreads int
Buffers int
ChunkTimeout time.Duration

View file

@ -46,10 +46,11 @@ type TestSuite struct {
func (suite *TestSuite) SetupTest() {
suite.config = &config.TGConfig{Stream: struct {
BotsOffset int
MultiThreads int
Buffers int
ChunkTimeout time.Duration
}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
}{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
}
func (suite *TestSuite) TestFullRead() {

View file

@ -202,7 +202,7 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i
cache := cache.FromContext(ctx)
key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId)
key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId)
err = cache.Get(key, location)
@ -221,12 +221,22 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i
if err != nil {
return nil, err
}
messages, _ := res.(*tg.MessagesChannelMessages)
item := messages.Messages[0].(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location = document.AsInputDocumentFileLocation()
cache.Set(key, location, 3600)
if len(messages.Messages) == 0 {
return nil, errors.New("no messages found")
}
switch item := messages.Messages[0].(type) {
case *tg.MessageEmpty:
return nil, errors.New("no messages found")
case *tg.Message:
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
location = document.AsInputDocumentFileLocation()
cache.Set(key, location, 1800)
}
}
return location, nil
}

View file

@ -147,9 +147,15 @@ func (fc *Controller) DeleteFiles(c *gin.Context) {
c.JSON(http.StatusOK, res)
}
func (fc *Controller) DeleteFileParts(c *gin.Context) {
func (fc *Controller) UpdateParts(c *gin.Context) {
res, err := fc.FileService.DeleteFileParts(c, c.Param("fileID"))
var payload schemas.PartUpdate
if err := c.ShouldBindJSON(&payload); err != nil {
httputil.NewError(c, http.StatusBadRequest, err)
return
}
res, err := fc.FileService.UpdateParts(c, c.Param("fileID"), &payload)
if err != nil {
httputil.NewError(c, err.Code, err.Error)
return

View file

@ -65,6 +65,7 @@ type FileUpdate struct {
Starred *bool `json:"starred,omitempty"`
ParentID string `json:"parentId,omitempty"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"`
Parts []Part `json:"parts,omitempty"`
Size *int64 `json:"size,omitempty"`
}
@ -82,6 +83,12 @@ type DeleteOperation struct {
Files []string `json:"files,omitempty"`
Source string `json:"source,omitempty"`
}
type PartUpdate struct {
Parts []Part `json:"parts"`
UploadId string `json:"uploadId"`
UpdatedAt time.Time `json:"updatedAt" binding:"required"`
Size int64 `json:"size"`
}
type DirMove struct {
Source string `json:"source" binding:"required"`

View file

@ -19,7 +19,7 @@ func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull,
cache := cache.FromContext(ctx)
parts := []types.Part{}
key := fmt.Sprintf("messages:%s:%s", file.Id, userID)
key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID)
err := cache.Get(key, &parts)

View file

@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"mime"
@ -38,6 +39,10 @@ import (
"gorm.io/gorm/clause"
)
var (
ErrorStreamAbandoned = errors.New("stream abandoned")
)
type buffer struct {
Buf []byte
}
@ -48,6 +53,7 @@ func (b *buffer) long() (int64, error) {
return 0, err
}
return int64(v), nil
}
func (b *buffer) uint64() (uint64, error) {
const size = 8
@ -149,6 +155,7 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
ParentID: update.ParentID,
UpdatedAt: update.UpdatedAt,
Size: update.Size,
CreatedAt: update.CreatedAt,
}
if update.Starred != nil {
@ -160,8 +167,6 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
}
chain = fs.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", id).Updates(updateDb)
cache.Delete(fmt.Sprintf("files:%s", id))
if chain.Error != nil {
return nil, &types.AppError{Error: chain.Error}
}
@ -169,6 +174,15 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
}
cache.Delete(fmt.Sprintf("files:%s", id))
if len(update.Parts) > 0 {
cache.Delete(fmt.Sprintf("files:messages:%s:%d", id, userId))
for _, part := range files[0].Parts {
cache.Delete(fmt.Sprintf("files:location:%d:%s:%d", userId, id, part.ID))
}
}
return mapper.ToFileOut(files[0]), nil
}
@ -356,32 +370,53 @@ func (fs *FileService) DeleteFiles(userId int64, payload *schemas.DeleteOperatio
return &schemas.Message{Message: "files deleted"}, nil
}
func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Message, *types.AppError) {
func (fs *FileService) UpdateParts(c *gin.Context, id string, payload *schemas.PartUpdate) (*schemas.Message, *types.AppError) {
var file models.File
if err := fs.db.Where("id = ?", id).First(&file).Error; err != nil {
if database.IsRecordNotFoundErr(err) {
return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
updatePayload := models.File{
UpdatedAt: payload.UpdatedAt,
Size: utils.Int64Pointer(payload.Size),
}
if len(payload.Parts) > 0 {
updatePayload.Parts = datatypes.NewJSONSlice(payload.Parts)
}
err := fs.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("id = ?", id).First(&file).Error; err != nil {
return err
}
return nil, &types.AppError{Error: err}
}
_, session := auth.GetUser(c)
if err := tx.Model(models.File{}).Where("id = ?", id).Updates(updatePayload).Error; err != nil {
return err
}
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
if payload.UploadId != "" {
if err := tx.Where("upload_id = ?", payload.UploadId).Delete(&models.Upload{}).Error; err != nil {
return err
}
}
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
err := tgc.DeleteMessages(c, client, *file.ChannelID, ids)
return nil
})
if err != nil {
return nil, &types.AppError{Error: err}
}
return &schemas.Message{Message: "file parts deleted"}, nil
if len(file.Parts) > 0 && file.ChannelID != nil {
_, session := auth.GetUser(c)
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
tgc.DeleteMessages(c, client, *file.ChannelID, ids)
}
return &schemas.Message{Message: "file updated"}, nil
}
func (fs *FileService) MoveDirectory(userId int64, payload *schemas.DirMove) (*schemas.Message, *types.AppError) {
@ -667,7 +702,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
client, err = fs.worker.UserWorker(session.Session, session.UserId)
if err != nil {
logger.Error("file stream", zap.Error(err))
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -675,13 +710,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
multiThreads = 0
} else {
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
fs.worker.Set(tokens[:limit], file.ChannelID)
offset := fs.cnf.TG.Stream.BotsOffset - 1
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset)
fs.worker.Set(tokens[offset:limit], file.ChannelID)
client, _, err = fs.worker.Next(file.ChannelID)
if err != nil {
logger.Error("file stream", zap.Error(err))
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -690,7 +724,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
if r.Method != "HEAD" {
parts, err := getParts(c, client.Tg.API(), file, channelUser)
if err != nil {
logger.Error("file stream", err)
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -705,7 +739,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
}
if err != nil {
logger.Error("file stream", err)
logger.Error(ErrorStreamAbandoned, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}