mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-03 21:14:28 +08:00
151 lines
4.6 KiB
Go
151 lines
4.6 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/tgdrive/teldrive/internal/api"
|
|
"github.com/tgdrive/teldrive/internal/appcontext"
|
|
"github.com/tgdrive/teldrive/internal/cache"
|
|
"github.com/tgdrive/teldrive/internal/database"
|
|
"github.com/tgdrive/teldrive/pkg/mapper"
|
|
"github.com/tgdrive/teldrive/pkg/models"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
var (
|
|
ErrShareNotFound = errors.New("share not found")
|
|
ErrInvalidPassword = errors.New("invalid password")
|
|
ErrEmptyAuth = errors.New("empty auth")
|
|
ErrShareExpired = errors.New("share expired")
|
|
)
|
|
|
|
type fileShare struct {
|
|
models.FileShare
|
|
Type api.FileShareInfoType
|
|
Name string
|
|
Path string
|
|
}
|
|
|
|
func (a *apiService) shareGetById(id string) (*fileShare, error) {
|
|
var result []fileShare
|
|
|
|
if err := a.db.Model(&models.FileShare{}).Where("file_shares.id = ?", id).
|
|
Select("file_shares.*", "f.type", "f.name",
|
|
"(select get_path_from_file_id as path from teldrive.get_path_from_file_id(f.id))").
|
|
Joins("left join teldrive.files as f on f.id = file_shares.file_id").
|
|
Scan(&result).Error; err != nil {
|
|
return nil, &apiError{err: err}
|
|
}
|
|
|
|
if len(result) == 0 {
|
|
return nil, &apiError{err: ErrShareNotFound, code: http.StatusNotFound}
|
|
}
|
|
|
|
if result[0].ExpiresAt != nil && result[0].ExpiresAt.Before(time.Now().UTC()) {
|
|
return nil, &apiError{err: ErrShareExpired, code: http.StatusNotFound}
|
|
}
|
|
|
|
return &result[0], nil
|
|
}
|
|
|
|
func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetByIdParams) (*api.FileShareInfo, error) {
|
|
share, err := a.shareGetById(params.ID)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := &api.FileShareInfo{
|
|
Protected: share.Password != nil,
|
|
UserId: share.UserId,
|
|
Type: share.Type,
|
|
Name: share.Name,
|
|
}
|
|
if share.ExpiresAt != nil {
|
|
res.ExpiresAt = api.NewOptDateTime(*share.ExpiresAt)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (a *apiService) SharesUnlock(ctx context.Context, req *api.ShareUnlock, params api.SharesUnlockParams) error {
|
|
var result []models.FileShare
|
|
|
|
if err := a.db.Model(&models.FileShare{}).Where("id = ?", params.ID).Find(&result).Error; err != nil {
|
|
return &apiError{err: err}
|
|
}
|
|
|
|
if len(result) == 0 {
|
|
return &apiError{err: ErrShareNotFound, code: http.StatusNotFound}
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(*result[0].Password), []byte(req.Password)); err != nil {
|
|
return &apiError{err: ErrInvalidPassword, code: http.StatusForbidden}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListFilesParams) (*api.FileList, error) {
|
|
c := ctx.(*appcontext.Context)
|
|
share, err := a.validFileShare(c.Request, params.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fileType := share.Type
|
|
|
|
if fileType == api.FileShareInfoTypeFolder {
|
|
queryBuilder := &fileQueryBuilder{db: a.db}
|
|
return queryBuilder.execute(&api.FilesListParams{
|
|
Path: api.NewOptString(share.Path + params.Path.Or("")),
|
|
Limit: params.Limit,
|
|
Page: params.Page,
|
|
Status: api.NewOptFileQueryStatus(api.FileQueryStatusActive),
|
|
Order: api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))),
|
|
Sort: api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))),
|
|
Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserId)
|
|
} else {
|
|
var file models.File
|
|
if err := a.db.Where("id = ?", share.FileId).First(&file).Error; err != nil {
|
|
if database.IsRecordNotFoundErr(err) {
|
|
return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound}
|
|
}
|
|
return nil, &apiError{err: err}
|
|
}
|
|
return &api.FileList{Items: []api.File{*mapper.ToFileOut(file)},
|
|
Meta: api.Meta{Count: 1, TotalPages: 1, CurrentPage: 1}}, nil
|
|
}
|
|
|
|
}
|
|
func (a *apiService) validFileShare(r *http.Request, id string) (*fileShare, error) {
|
|
|
|
share, err := cache.FetchArg(a.cache, cache.Key("shares", id), 0, a.shareGetById, id)
|
|
|
|
if err != nil {
|
|
return nil, &apiError{err: err}
|
|
}
|
|
|
|
if share.Password != nil {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
return nil, &apiError{err: ErrEmptyAuth, code: http.StatusUnauthorized}
|
|
}
|
|
bytes, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic "))
|
|
if err != nil {
|
|
return nil, &apiError{err: err}
|
|
}
|
|
password := strings.Split(string(bytes), ":")[1]
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(*share.Password), []byte(password)); err != nil {
|
|
return nil, &apiError{err: ErrInvalidPassword, code: http.StatusUnauthorized}
|
|
}
|
|
|
|
}
|
|
return share, nil
|
|
}
|
|
|
|
func (a *apiService) SharesStream(ctx context.Context, params api.SharesStreamParams) (api.SharesStreamRes, error) {
|
|
return nil, nil
|
|
}
|