mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-05 22:14:30 +08:00
refactor: Handle file updates in transaction (#292)
* feat: Add Stream Bots Offset flag to run command * feat: Update Stream Bots Offset flag in run command * update * update * update * invalidate cache on update * chore: Update cache key for file location deletion * chore: Update cache key for file location retrieval * chore: Refactor file service to optimize token processing * chore: Optimize token processing in file service * chore: Update file service to use PUT method for updating parts
This commit is contained in:
parent
678de5bc63
commit
758f0f82ea
9 changed files with 97 additions and 37 deletions
|
@ -30,7 +30,7 @@ func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gi
|
|||
files.GET(":fileID/stream/:fileName", c.GetFileStream)
|
||||
files.HEAD(":fileID/download/:fileName", c.GetFileDownload)
|
||||
files.GET(":fileID/download/:fileName", c.GetFileDownload)
|
||||
files.DELETE(":fileID/parts", authmiddleware, c.DeleteFileParts)
|
||||
files.PUT(":fileID/parts", authmiddleware, c.UpdateParts)
|
||||
files.GET("/category/stats", authmiddleware, c.GetCategoryStats)
|
||||
files.POST("/move", authmiddleware, c.MoveFiles)
|
||||
files.POST("/directories", authmiddleware, c.MakeDirectory)
|
||||
|
|
|
@ -98,6 +98,7 @@ func NewRun() *cobra.Command {
|
|||
duration.DurationVar(runCmd.Flags(), &config.TG.Uploads.Retention, "tg-uploads-retention", (24*7)*time.Hour, "Uploads retention duration")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.MultiThreads, "tg-stream-multi-threads", 0, "Stream multi-threads")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.Buffers, "tg-stream-buffers", 8, "No of Stream buffers")
|
||||
runCmd.Flags().IntVar(&config.TG.Stream.BotsOffset, "tg-stream-bots-offset", 1, "Stream Bots Offset")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.Stream.ChunkTimeout, "tg-stream-chunk-timeout", 30*time.Second, "Chunk Fetch Timeout")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsTimeout, "tg-bg-bots-timeout", 30*time.Minute, "Stop Timeout for Idle background bots")
|
||||
duration.DurationVar(runCmd.Flags(), &config.TG.BgBotsCheckInterval, "tg-bg-bots-check-interval", 5*time.Minute, "Interval for checking Idle background bots")
|
||||
|
|
|
@ -51,6 +51,7 @@ type TGConfig struct {
|
|||
Retention time.Duration
|
||||
}
|
||||
Stream struct {
|
||||
BotsOffset int
|
||||
MultiThreads int
|
||||
Buffers int
|
||||
ChunkTimeout time.Duration
|
||||
|
|
|
@ -46,10 +46,11 @@ type TestSuite struct {
|
|||
|
||||
func (suite *TestSuite) SetupTest() {
|
||||
suite.config = &config.TGConfig{Stream: struct {
|
||||
BotsOffset int
|
||||
MultiThreads int
|
||||
Buffers int
|
||||
ChunkTimeout time.Duration
|
||||
}{MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
|
||||
}{BotsOffset: 1, MultiThreads: 8, Buffers: 10, ChunkTimeout: 1 * time.Second}}
|
||||
}
|
||||
|
||||
func (suite *TestSuite) TestFullRead() {
|
||||
|
|
|
@ -202,7 +202,7 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i
|
|||
|
||||
cache := cache.FromContext(ctx)
|
||||
|
||||
key := fmt.Sprintf("location:%s:%s:%d", client.UserId, fileId, partId)
|
||||
key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId)
|
||||
|
||||
err = cache.Get(key, location)
|
||||
|
||||
|
@ -221,12 +221,22 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages, _ := res.(*tg.MessagesChannelMessages)
|
||||
item := messages.Messages[0].(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
location = document.AsInputDocumentFileLocation()
|
||||
cache.Set(key, location, 3600)
|
||||
|
||||
if len(messages.Messages) == 0 {
|
||||
return nil, errors.New("no messages found")
|
||||
}
|
||||
|
||||
switch item := messages.Messages[0].(type) {
|
||||
case *tg.MessageEmpty:
|
||||
return nil, errors.New("no messages found")
|
||||
case *tg.Message:
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
location = document.AsInputDocumentFileLocation()
|
||||
cache.Set(key, location, 1800)
|
||||
}
|
||||
}
|
||||
return location, nil
|
||||
}
|
||||
|
|
|
@ -147,9 +147,15 @@ func (fc *Controller) DeleteFiles(c *gin.Context) {
|
|||
c.JSON(http.StatusOK, res)
|
||||
}
|
||||
|
||||
func (fc *Controller) DeleteFileParts(c *gin.Context) {
|
||||
func (fc *Controller) UpdateParts(c *gin.Context) {
|
||||
|
||||
res, err := fc.FileService.DeleteFileParts(c, c.Param("fileID"))
|
||||
var payload schemas.PartUpdate
|
||||
if err := c.ShouldBindJSON(&payload); err != nil {
|
||||
httputil.NewError(c, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
res, err := fc.FileService.UpdateParts(c, c.Param("fileID"), &payload)
|
||||
if err != nil {
|
||||
httputil.NewError(c, err.Code, err.Error)
|
||||
return
|
||||
|
|
|
@ -65,6 +65,7 @@ type FileUpdate struct {
|
|||
Starred *bool `json:"starred,omitempty"`
|
||||
ParentID string `json:"parentId,omitempty"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
Parts []Part `json:"parts,omitempty"`
|
||||
Size *int64 `json:"size,omitempty"`
|
||||
}
|
||||
|
@ -82,6 +83,12 @@ type DeleteOperation struct {
|
|||
Files []string `json:"files,omitempty"`
|
||||
Source string `json:"source,omitempty"`
|
||||
}
|
||||
type PartUpdate struct {
|
||||
Parts []Part `json:"parts"`
|
||||
UploadId string `json:"uploadId"`
|
||||
UpdatedAt time.Time `json:"updatedAt" binding:"required"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
type DirMove struct {
|
||||
Source string `json:"source" binding:"required"`
|
||||
|
|
|
@ -19,7 +19,7 @@ func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull,
|
|||
cache := cache.FromContext(ctx)
|
||||
parts := []types.Part{}
|
||||
|
||||
key := fmt.Sprintf("messages:%s:%s", file.Id, userID)
|
||||
key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID)
|
||||
|
||||
err := cache.Get(key, &parts)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
|
@ -38,6 +39,10 @@ import (
|
|||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorStreamAbandoned = errors.New("stream abandoned")
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
Buf []byte
|
||||
}
|
||||
|
@ -48,6 +53,7 @@ func (b *buffer) long() (int64, error) {
|
|||
return 0, err
|
||||
}
|
||||
return int64(v), nil
|
||||
|
||||
}
|
||||
func (b *buffer) uint64() (uint64, error) {
|
||||
const size = 8
|
||||
|
@ -149,6 +155,7 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
|
|||
ParentID: update.ParentID,
|
||||
UpdatedAt: update.UpdatedAt,
|
||||
Size: update.Size,
|
||||
CreatedAt: update.CreatedAt,
|
||||
}
|
||||
|
||||
if update.Starred != nil {
|
||||
|
@ -160,8 +167,6 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
|
|||
}
|
||||
chain = fs.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", id).Updates(updateDb)
|
||||
|
||||
cache.Delete(fmt.Sprintf("files:%s", id))
|
||||
|
||||
if chain.Error != nil {
|
||||
return nil, &types.AppError{Error: chain.Error}
|
||||
}
|
||||
|
@ -169,6 +174,15 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU
|
|||
return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
|
||||
}
|
||||
|
||||
cache.Delete(fmt.Sprintf("files:%s", id))
|
||||
|
||||
if len(update.Parts) > 0 {
|
||||
cache.Delete(fmt.Sprintf("files:messages:%s:%d", id, userId))
|
||||
for _, part := range files[0].Parts {
|
||||
cache.Delete(fmt.Sprintf("files:location:%d:%s:%d", userId, id, part.ID))
|
||||
}
|
||||
}
|
||||
|
||||
return mapper.ToFileOut(files[0]), nil
|
||||
|
||||
}
|
||||
|
@ -356,32 +370,53 @@ func (fs *FileService) DeleteFiles(userId int64, payload *schemas.DeleteOperatio
|
|||
return &schemas.Message{Message: "files deleted"}, nil
|
||||
}
|
||||
|
||||
func (fs *FileService) DeleteFileParts(c *gin.Context, id string) (*schemas.Message, *types.AppError) {
|
||||
func (fs *FileService) UpdateParts(c *gin.Context, id string, payload *schemas.PartUpdate) (*schemas.Message, *types.AppError) {
|
||||
|
||||
var file models.File
|
||||
if err := fs.db.Where("id = ?", id).First(&file).Error; err != nil {
|
||||
if database.IsRecordNotFoundErr(err) {
|
||||
return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound}
|
||||
|
||||
updatePayload := models.File{
|
||||
UpdatedAt: payload.UpdatedAt,
|
||||
Size: utils.Int64Pointer(payload.Size),
|
||||
}
|
||||
|
||||
if len(payload.Parts) > 0 {
|
||||
updatePayload.Parts = datatypes.NewJSONSlice(payload.Parts)
|
||||
}
|
||||
|
||||
err := fs.db.Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
if err := tx.Where("id = ?", id).First(&file).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil, &types.AppError{Error: err}
|
||||
}
|
||||
|
||||
_, session := auth.GetUser(c)
|
||||
if err := tx.Model(models.File{}).Where("id = ?", id).Updates(updatePayload).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||
if payload.UploadId != "" {
|
||||
if err := tx.Where("upload_id = ?", payload.UploadId).Delete(&models.Upload{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ids := []int{}
|
||||
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
|
||||
err := tgc.DeleteMessages(c, client, *file.ChannelID, ids)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, &types.AppError{Error: err}
|
||||
}
|
||||
|
||||
return &schemas.Message{Message: "file parts deleted"}, nil
|
||||
if len(file.Parts) > 0 && file.ChannelID != nil {
|
||||
_, session := auth.GetUser(c)
|
||||
ids := []int{}
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
client, _ := tgc.AuthClient(c, &fs.cnf.TG, session)
|
||||
tgc.DeleteMessages(c, client, *file.ChannelID, ids)
|
||||
}
|
||||
|
||||
return &schemas.Message{Message: "file updated"}, nil
|
||||
}
|
||||
|
||||
func (fs *FileService) MoveDirectory(userId int64, payload *schemas.DirMove) (*schemas.Message, *types.AppError) {
|
||||
|
@ -667,7 +702,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
if fs.cnf.TG.DisableStreamBots || len(tokens) == 0 {
|
||||
client, err = fs.worker.UserWorker(session.Session, session.UserId)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
logger.Error(ErrorStreamAbandoned, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -675,13 +710,12 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
multiThreads = 0
|
||||
|
||||
} else {
|
||||
|
||||
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit)
|
||||
|
||||
fs.worker.Set(tokens[:limit], file.ChannelID)
|
||||
offset := fs.cnf.TG.Stream.BotsOffset - 1
|
||||
limit := min(len(tokens), fs.cnf.TG.BgBotsLimit+offset)
|
||||
fs.worker.Set(tokens[offset:limit], file.ChannelID)
|
||||
client, _, err = fs.worker.Next(file.ChannelID)
|
||||
if err != nil {
|
||||
logger.Error("file stream", zap.Error(err))
|
||||
logger.Error(ErrorStreamAbandoned, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -690,7 +724,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
if r.Method != "HEAD" {
|
||||
parts, err := getParts(c, client.Tg.API(), file, channelUser)
|
||||
if err != nil {
|
||||
logger.Error("file stream", err)
|
||||
logger.Error(ErrorStreamAbandoned, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -705,7 +739,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error("file stream", err)
|
||||
logger.Error(ErrorStreamAbandoned, err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue