teldrive/pkg/services/share.go
2025-05-14 23:39:40 +05:30

151 lines
4.6 KiB
Go

package services
import (
"context"
"encoding/base64"
"errors"
"net/http"
"strings"
"time"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/appcontext"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/database"
"github.com/tgdrive/teldrive/pkg/mapper"
"github.com/tgdrive/teldrive/pkg/models"
"golang.org/x/crypto/bcrypt"
)
// Sentinel errors returned (wrapped in apiError) by the share handlers in this file.
var (
// ErrShareNotFound is used when no share row matches the requested id.
ErrShareNotFound = errors.New("share not found")
// ErrInvalidPassword is used when the supplied password does not match the share's bcrypt hash.
ErrInvalidPassword = errors.New("invalid password")
// ErrEmptyAuth is used when a password-protected share is accessed without an Authorization header.
ErrEmptyAuth = errors.New("empty auth")
// ErrShareExpired is used when the share's ExpiresAt lies in the past.
ErrShareExpired = errors.New("share expired")
)
// fileShare is the scan target used by shareGetById: the share row itself plus
// a few columns selected from the joined files table.
type fileShare struct {
models.FileShare
// Type is selected from the joined file's "type" column (f.type).
Type api.FileShareInfoType
// Name is selected from the joined file's "name" column (f.name).
Name string
// Path is selected from teldrive.get_path_from_file_id(f.id) for the shared file.
Path string
}
// shareGetById loads the share with the given id together with the type, name
// and path of the file it points to.
//
// It returns an apiError carrying http.StatusNotFound when no such share exists
// (ErrShareNotFound) or when the share's expiry time has already passed
// (ErrShareExpired). Database failures are returned as an apiError without an
// explicit status code.
func (a *apiService) shareGetById(id string) (*fileShare, error) {
	// Sub-select resolving the path of the shared file.
	const pathColumn = "(select get_path_from_file_id as path from teldrive.get_path_from_file_id(f.id))"

	query := a.db.Model(&models.FileShare{}).
		Where("file_shares.id = ?", id).
		Select("file_shares.*", "f.type", "f.name", pathColumn).
		Joins("left join teldrive.files as f on f.id = file_shares.file_id")

	var rows []fileShare
	if err := query.Scan(&rows).Error; err != nil {
		return nil, &apiError{err: err}
	}
	if len(rows) == 0 {
		return nil, &apiError{err: ErrShareNotFound, code: http.StatusNotFound}
	}

	share := &rows[0]
	// A share without an expiry time never expires.
	if expiry := share.ExpiresAt; expiry != nil && expiry.Before(time.Now().UTC()) {
		return nil, &apiError{err: ErrShareExpired, code: http.StatusNotFound}
	}
	return share, nil
}
// SharesGetById returns the public description of the share identified by
// params.ID: owner, type and name of the shared file, whether a password is
// set, and the expiry time when one is configured.
//
// Errors from shareGetById (not found, expired, database failure) are returned
// unchanged. ctx is not used.
func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetByIdParams) (*api.FileShareInfo, error) {
	share, err := a.shareGetById(params.ID)
	if err != nil {
		return nil, err
	}

	info := api.FileShareInfo{
		Name:   share.Name,
		Type:   share.Type,
		UserId: share.UserId,
		// A share counts as protected as soon as a password hash is stored.
		Protected: share.Password != nil,
	}
	if expiry := share.ExpiresAt; expiry != nil {
		info.ExpiresAt = api.NewOptDateTime(*expiry)
	}
	return &info, nil
}
// SharesUnlock checks req.Password against the bcrypt hash stored for the
// share identified by params.ID.
//
// It returns nil when the password matches, or when the share has no password
// at all (there is nothing to unlock; validFileShare likewise skips the
// password check for such shares). It returns an apiError with
// http.StatusNotFound when the share does not exist and with
// http.StatusForbidden when the password does not match. Database failures are
// returned as an apiError without an explicit status code. ctx is not used.
func (a *apiService) SharesUnlock(ctx context.Context, req *api.ShareUnlock, params api.SharesUnlockParams) error {
	var result []models.FileShare
	if err := a.db.Model(&models.FileShare{}).Where("id = ?", params.ID).Find(&result).Error; err != nil {
		return &apiError{err: err}
	}
	if len(result) == 0 {
		return &apiError{err: ErrShareNotFound, code: http.StatusNotFound}
	}

	hash := result[0].Password
	if hash == nil {
		// Unprotected share: dereferencing the nil password would panic.
		return nil
	}
	if err := bcrypt.CompareHashAndPassword([]byte(*hash), []byte(req.Password)); err != nil {
		return &apiError{err: ErrInvalidPassword, code: http.StatusForbidden}
	}
	return nil
}
// SharesListFiles lists the content reachable through the share params.ID.
//
// The share is validated first (existence, expiry, and password when one is
// set) using the HTTP request carried by ctx. ctx must be an
// *appcontext.Context; the type assertion is unchecked and panics otherwise.
//
// For a folder share, the files under the share's path joined with the
// optional params.Path are listed on behalf of the share owner. For any other
// share type, the result is a single-item list holding the shared file.
func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListFilesParams) (*api.FileList, error) {
	appCtx := ctx.(*appcontext.Context)
	share, err := a.validFileShare(appCtx.Request, params.ID)
	if err != nil {
		return nil, err
	}

	// Single-file share: return just that file as a one-page list.
	if share.Type != api.FileShareInfoTypeFolder {
		var shared models.File
		if err := a.db.Where("id = ?", share.FileId).First(&shared).Error; err != nil {
			if database.IsRecordNotFoundErr(err) {
				return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound}
			}
			return nil, &apiError{err: err}
		}
		list := &api.FileList{
			Items: []api.File{*mapper.ToFileOut(shared)},
			Meta:  api.Meta{Count: 1, TotalPages: 1, CurrentPage: 1},
		}
		return list, nil
	}

	// Folder share: run a regular listing query rooted at the shared folder.
	builder := &fileQueryBuilder{db: a.db}
	listParams := &api.FilesListParams{
		Path:      api.NewOptString(share.Path + params.Path.Or("")),
		Limit:     params.Limit,
		Page:      params.Page,
		Status:    api.NewOptFileQueryStatus(api.FileQueryStatusActive),
		Order:     api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))),
		Sort:      api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))),
		Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList),
	}
	return builder.execute(listParams, share.UserId)
}
// validFileShare resolves the share with the given id (through the cache,
// falling back to shareGetById) and, when the share is password protected,
// authenticates the request against it.
//
// The password is read from the Authorization header, which is expected to
// hold base64-encoded "user:password" credentials, optionally prefixed with
// "Basic ". The user part is ignored; everything after the first colon is the
// password, so passwords may themselves contain colons.
//
// It returns an apiError with http.StatusUnauthorized when the header is
// missing (ErrEmptyAuth), or when the credentials lack a colon separator or
// the password does not match (ErrInvalidPassword). Lookup and base64 decoding
// failures are returned as an apiError without an explicit status code.
func (a *apiService) validFileShare(r *http.Request, id string) (*fileShare, error) {
	share, err := cache.FetchArg(a.cache, cache.Key("shares", id), 0, a.shareGetById, id)
	if err != nil {
		return nil, &apiError{err: err}
	}
	// Unprotected shares need no credentials.
	if share.Password == nil {
		return share, nil
	}

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return nil, &apiError{err: ErrEmptyAuth, code: http.StatusUnauthorized}
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic "))
	if err != nil {
		return nil, &apiError{err: err}
	}

	// Split once on the first colon. Indexing the result of strings.Split
	// panicked on credentials without a colon and truncated passwords that
	// contain one.
	_, password, found := strings.Cut(string(decoded), ":")
	if !found {
		return nil, &apiError{err: ErrInvalidPassword, code: http.StatusUnauthorized}
	}
	if err := bcrypt.CompareHashAndPassword([]byte(*share.Password), []byte(password)); err != nil {
		return nil, &apiError{err: ErrInvalidPassword, code: http.StatusUnauthorized}
	}
	return share, nil
}
// SharesStream is a stub in this implementation: it ignores ctx and params and
// returns a nil response together with a nil error.
func (a *apiService) SharesStream(ctx context.Context, params api.SharesStreamParams) (api.SharesStreamRes, error) {
return nil, nil
}