refactor: simplify condition checks and improve upload handling

This commit is contained in:
divyam234 2025-01-01 18:21:45 +05:30
parent a296b368de
commit 77e463b998
12 changed files with 105 additions and 110 deletions

View file

@ -2,6 +2,10 @@ version: 2
project_name: teldrive
env:
- GO111MODULE=on
before:
hooks:
- go generate ./...
builds:
- env:

View file

@ -17,7 +17,7 @@ GOARCH ?= $(shell go env GOARCH)
VERSION:= $(GIT_TAG)
BINARY_EXTENSION :=
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version generate
.PHONY: all build run clean frontend backend run sync-ui retag patch-version minor-version gen
all: build

View file

@ -3,13 +3,8 @@ package chizap
import (
"context"
"net"
"net/http"
"net/http/httputil"
"os"
"regexp"
"runtime/debug"
"strings"
"time"
"github.com/go-chi/chi/v5/middleware"
@ -115,58 +110,3 @@ func ChizapWithConfig(logger ZapLogger, conf *Config) func(next http.Handler) ht
})
}
}
func defaultHandleRecovery(w http.ResponseWriter, r *http.Request, err interface{}) {
w.WriteHeader(http.StatusInternalServerError)
}
func RecoveryWithZap(logger ZapLogger, stack bool) func(next http.Handler) http.Handler {
return CustomRecoveryWithZap(logger, stack, defaultHandleRecovery)
}
func CustomRecoveryWithZap(logger ZapLogger, stack bool, recovery func(w http.ResponseWriter, r *http.Request, err interface{})) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") ||
strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(r, false)
if brokenPipe {
logger.Error(r.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
http.Error(w, "connection broken", http.StatusInternalServerError)
return
}
if stack {
logger.Error("[Recovery from panic]",
zap.Time("time", time.Now()),
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
logger.Error("[Recovery from panic]",
zap.Time("time", time.Now()),
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
recovery(w, r, err)
}
}()
next.ServeHTTP(w, r)
})
}
}

View file

@ -28,13 +28,21 @@ func SPAHandler(filesystem fs.FS) http.HandlerFunc {
logging.DefaultLogger().Fatal(err.Error())
}
return func(w http.ResponseWriter, r *http.Request) {
f, err := spaFS.Open(strings.TrimPrefix(path.Clean(r.URL.Path), "/"))
filePath := strings.TrimPrefix(path.Clean(r.URL.Path), "/")
f, err := spaFS.Open(filePath)
if err == nil {
defer f.Close()
}
if os.IsNotExist(err) {
r.URL.Path = "/"
filePath = "index.html"
}
if filePath == "index.html" {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
} else {
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
}
http.FileServer(http.FS(spaFS)).ServeHTTP(w, r)
}
}

View file

@ -140,7 +140,7 @@ func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
reader io.ReadCloser
err error
)
if r.file.Encrypted.IsSet() && r.file.Encrypted.Value {
if r.file.Encrypted.Value {
salt := r.parts[r.ranges[r.pos].PartNo].Salt
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
reader, err = cipher.DecryptDataSeek(r.ctx,

View file

@ -29,3 +29,20 @@ func ToFileOut(file models.File, extended bool) *api.File {
}
return res
}
func ToUploadOut(parts []models.Upload) []api.UploadPart {
res := []api.UploadPart{}
for _, part := range parts {
res = append(res, api.UploadPart{
Name: part.Name,
PartId: part.PartId,
ChannelId: part.ChannelID,
PartNo: part.PartNo,
Size: part.Size,
Encrypted: part.Encrypted,
Salt: api.NewOptString(part.Salt),
})
}
return res
}

View file

@ -134,7 +134,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
}
func (a *apiService) AuthSession(ctx context.Context, params api.AuthSessionParams) (api.AuthSessionRes, error) {
if !params.AccessToken.IsSet() {
if params.AccessToken.Value == "" {
return &api.AuthSessionNoContent{}, nil
}
claims, err := auth.VerifyUser(a.db, a.cache, a.cnf.JWT.Secret, params.AccessToken.Value)

View file

@ -49,7 +49,7 @@ func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher,
Size: document.Size,
Salt: file.Parts[i].Salt.Value,
}
if file.Encrypted.IsSet() && file.Encrypted.Value {
if file.Encrypted.Value {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
}
parts = append(parts, part)

View file

@ -196,7 +196,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
dbFile.Name = req.NewName.Or(file.Name)
dbFile.Size = utils.Ptr(file.Size.Value)
dbFile.Type = string(file.Type)
dbFile.MimeType = file.MimeType.Value
dbFile.MimeType = file.MimeType.Or(defaultContentType)
dbFile.Parts = datatypes.NewJSONSlice(newIds)
dbFile.UserID = userId
dbFile.Status = "active"
@ -226,7 +226,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
channelId int64
)
if fileIn.Path.IsSet() {
if fileIn.Path.Value != "" {
path = strings.TrimSpace(fileIn.Path.Value)
path = strings.ReplaceAll(path, "//", "/")
if path != "/" {
@ -234,7 +234,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
}
}
if path != "" && !fileIn.ParentId.IsSet() {
if path != "" && fileIn.ParentId.Value == "" {
parent, err = a.getFileFromPath(path, userId)
if err != nil {
return nil, &apiError{err: err, code: 404}
@ -243,7 +243,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
String: parent.Id,
Valid: true,
}
} else if fileIn.ParentId.IsSet() {
} else if fileIn.ParentId.Value != "" {
fileDB.ParentID = sql.NullString{
String: fileIn.ParentId.Value,
Valid: true,
@ -257,7 +257,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
fileDB.MimeType = "drive/folder"
fileDB.Parts = nil
} else if fileIn.Type == "file" {
if !fileIn.ChannelId.IsSet() {
if fileIn.ChannelId.Value == 0 {
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, &apiError{err: err}
@ -266,10 +266,18 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
channelId = fileIn.ChannelId.Value
}
fileDB.ChannelID = &channelId
fileDB.MimeType = fileIn.MimeType.Or("application/octet-stream")
fileDB.MimeType = fileIn.MimeType.Value
fileDB.Category = string(category.GetCategory(fileIn.Name))
if len(fileIn.Parts) > 0 {
fileDB.Parts = datatypes.NewJSONSlice(fileIn.Parts)
parts := []api.Part{}
for _, part := range fileIn.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
fileDB.Parts = datatypes.NewJSONSlice(parts)
}
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
}
@ -277,7 +285,7 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
fileDB.Type = string(fileIn.Type)
fileDB.UserID = userId
fileDB.Status = "active"
fileDB.Encrypted = fileIn.Encrypted.Or(false)
fileDB.Encrypted = fileIn.Encrypted.Value
if err := a.db.Create(&fileDB).Error; err != nil {
if database.IsKeyConflictErr(err) {
return nil, &apiError{err: errors.New("file already exists"), code: 409}
@ -292,7 +300,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
var fileShare models.FileShare
if req.Password.IsSet() {
if req.Password.Value != "" {
bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost)
if err != nil {
return &apiError{err: err}
@ -315,14 +323,14 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
userId, _ := auth.GetUser(ctx)
if !req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
}
if req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value != "" && len(req.Ids) == 0 {
if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil {
return &apiError{err: err}
}
} else if !req.Source.IsSet() && len(req.Ids) > 0 {
} else if req.Source.Value == "" && len(req.Ids) > 0 {
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
return &apiError{err: err}
}
@ -351,7 +359,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
var fileShareUpdate models.FileShare
if req.Password.IsSet() {
if req.Password.Value != "" {
bytes, err := bcrypt.GenerateFromPassword([]byte(req.Password.Value), bcrypt.MinCost)
if err != nil {
return &apiError{err: err}
@ -409,15 +417,15 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
userId, _ := auth.GetUser(ctx)
if !req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
}
if !req.Source.IsSet() && len(req.Ids) > 0 {
if req.Source.Value != "" && len(req.Ids) > 0 {
if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil {
return &apiError{err: err}
}
}
if req.Source.IsSet() && len(req.Ids) == 0 {
if req.Source.Value == "" && len(req.Ids) == 0 {
if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value,
req.Destination, userId).Error; err != nil {
return &apiError{err: err}
@ -466,13 +474,21 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
chain *gorm.DB
)
updateDb := models.File{}
if req.Name.IsSet() {
if req.Name.Value != "" {
updateDb.Name = req.Name.Value
}
if len(req.Parts) > 0 {
updateDb.Parts = datatypes.NewJSONSlice(req.Parts)
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updateDb.Parts = datatypes.NewJSONSlice(parts)
}
if req.Size.IsSet() {
if req.Size.Value != 0 {
updateDb.Size = utils.Ptr(req.Size.Value)
}
if req.UpdatedAt.IsSet() {
@ -503,7 +519,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
Size: utils.Ptr(req.Size),
}
if !req.ChannelId.IsSet() {
if req.ChannelId.Value == 0 {
channelId, err := getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return &apiError{err: err}
@ -513,7 +529,15 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
updatePayload.ChannelID = &req.ChannelId.Value
}
if len(req.Parts) > 0 {
updatePayload.Parts = datatypes.NewJSONSlice(req.Parts)
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updatePayload.Parts = datatypes.NewJSONSlice(parts)
}
err := a.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil {
@ -522,7 +546,7 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil {
return err
}
if req.UploadId.IsSet() {
if req.UploadId.Value != "" {
if err := tx.Where("upload_id = ?", req.UploadId.Value).Delete(&models.Upload{}).Error; err != nil {
return err
}

View file

@ -63,24 +63,24 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int
}
func (afb *fileQueryBuilder) applyListFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB {
if filesQuery.Path.IsSet() && !filesQuery.ParentId.IsSet() {
if filesQuery.Path.Value != "" && filesQuery.ParentId.Value == "" {
query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true)
}
if filesQuery.ParentId.IsSet() {
if filesQuery.ParentId.Value != "" {
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
}
return query
}
func (afb *fileQueryBuilder) applyFindFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB {
if filesQuery.DeepSearch.IsSet() && filesQuery.DeepSearch.Value && filesQuery.Query.IsSet() && filesQuery.Path.IsSet() {
if filesQuery.DeepSearch.Value && filesQuery.Query.Value != "" && filesQuery.Path.Value != "" {
query = query.Where("files.id in (select id from subdirs)")
}
if filesQuery.UpdatedAt.IsSet() {
if filesQuery.UpdatedAt.Value != "" {
query, _ = afb.applyDateFilters(query, filesQuery.UpdatedAt.Value)
}
if filesQuery.Query.IsSet() {
if filesQuery.Query.Value != "" {
query = afb.applySearchQuery(query, filesQuery)
}
@ -92,23 +92,24 @@ func (afb *fileQueryBuilder) applyFindFilters(query *gorm.DB, filesQuery *api.Fi
}
func (afb *fileQueryBuilder) applyFileSpecificFilters(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB {
if filesQuery.Name.IsSet() {
if filesQuery.Name.Value != "" {
query = query.Where("name = ?", filesQuery.Name.Value)
}
if filesQuery.ParentId.IsSet() {
if filesQuery.ParentId.Value != "" {
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
}
if !filesQuery.ParentId.IsSet() && filesQuery.Path.IsSet() && !filesQuery.Query.IsSet() {
query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true)
if filesQuery.ParentId.Value == "" && filesQuery.Path.Value != "" && filesQuery.Query.Value == "" {
query = query.Where("parent_id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))",
filesQuery.Path.Value, userId, true)
}
if filesQuery.Type.IsSet() {
if filesQuery.Type.Value != "" {
query = query.Where("type = ?", filesQuery.Type.Value)
}
if filesQuery.Shared.IsSet() && filesQuery.Shared.Value {
if filesQuery.Shared.Value {
query = query.Where("id in (SELECT file_id FROM teldrive.file_shares where user_id = ?)", userId)
}
@ -201,7 +202,7 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File
}
func (afb *fileQueryBuilder) buildSubqueryCTE(query *gorm.DB, filesQuery *api.FilesListParams, userId int64) *gorm.DB {
if filesQuery.DeepSearch.IsSet() && filesQuery.DeepSearch.Value && filesQuery.Query.IsSet() && filesQuery.Path.IsSet() {
if filesQuery.DeepSearch.Value && filesQuery.Query.Value != "" && filesQuery.Path.Value != "" {
return afb.db.Clauses(exclause.With{Recursive: true, CTEs: []exclause.CTE{{Name: "subdirs",
Subquery: exclause.Subquery{DB: afb.db.Model(&models.File{}).Select("id", "parent_id").
Where("id in (SELECT id FROM teldrive.get_file_from_path(?, ?, ?))", filesQuery.Path.Value, userId, true).

View file

@ -24,6 +24,7 @@ import (
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/telegram/uploader"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/pkg/mapper"
"github.com/tgdrive/teldrive/pkg/models"
)
@ -37,13 +38,13 @@ func (a *apiService) UploadsDelete(ctx context.Context, params api.UploadsDelete
}
func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPartsByIdParams) ([]api.UploadPart, error) {
parts := []api.UploadPart{}
parts := []models.Upload{}
if err := a.db.Model(&models.Upload{}).Order("part_no").Where("upload_id = ?", params.ID).
Where("created_at < ?", time.Now().UTC().Add(a.cnf.TG.Uploads.Retention)).
Find(&parts).Error; err != nil {
return nil, &apiError{err: err}
}
return parts, nil
return mapper.ToUploadOut(parts), nil
}
func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
@ -74,7 +75,7 @@ func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsPa
return stats, nil
}
func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq, params api.UploadsUploadParams) (*api.UploadPart, error) {
func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadReqWithContentType, params api.UploadsUploadParams) (*api.UploadPart, error) {
var (
channelId int64
err error
@ -85,17 +86,17 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq
out api.UploadPart
)
if !params.Encrypted.IsSet() && a.cnf.TG.Uploads.EncryptionKey == "" {
if params.Encrypted.Value && a.cnf.TG.Uploads.EncryptionKey == "" {
return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
}
userId, session := auth.GetUser(ctx)
fileStream := req.Data
fileStream := req.Content.Data
fileSize := params.ContentLength
if !params.ChannelId.IsSet() {
if params.ChannelId.Value == 0 {
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, err
@ -158,7 +159,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq
var salt string
if params.Encrypted.IsSet() {
if params.Encrypted.Value {
//gen random Salt
salt, _ = generateRandomSalt()
cipher, err := crypt.NewCipher(a.cnf.TG.Uploads.EncryptionKey, salt)
@ -166,7 +167,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq
return err
}
fileSize = crypt.EncryptedSize(fileSize)
fileStream, err = cipher.EncryptData(req.Data)
fileStream, err = cipher.EncryptData(fileStream)
if err != nil {
return err
}
@ -219,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req api.UploadsUploadReq
Size: fileSize,
PartNo: int(params.PartNo),
UserId: userId,
Encrypted: params.Encrypted.IsSet(),
Encrypted: params.Encrypted.Value,
Salt: salt,
}

View file

@ -246,10 +246,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd
channel := &models.Channel{UserID: userId, Selected: true}
if req.ChannelId.IsSet() {
if req.ChannelId.Value != 0 {
channel.ChannelID = req.ChannelId.Value
}
if req.ChannelName.IsSet() {
if req.ChannelName.Value != "" {
channel.ChannelName = req.ChannelName.Value
}