refactor: standardize field naming and improve cache key generation

This commit is contained in:
Bhunter 2025-01-13 08:07:46 +01:00
parent 299f4fa7ec
commit f0187f4052
26 changed files with 516 additions and 408 deletions

4
go.mod
View file

@ -8,13 +8,12 @@ require (
github.com/go-chi/chi/v5 v5.2.0
github.com/go-chi/cors v1.2.1
github.com/go-co-op/gocron v1.37.0
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/gotd/contrib v0.21.0
github.com/gotd/td v0.117.0
github.com/iyear/connectproxy v0.1.1
github.com/mitchellh/go-homedir v1.1.0
github.com/mitchellh/mapstructure v1.5.0
github.com/ogen-go/ogen v1.8.1
github.com/redis/go-redis/v9 v9.7.0
github.com/spf13/cobra v1.8.1
@ -50,6 +49,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect

4
go.sum
View file

@ -79,6 +79,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
@ -148,8 +150,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw=
github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=

View file

@ -42,10 +42,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
}
func GetUser(c context.Context) (int64, string) {
func GetUser(c context.Context) int64 {
authUser, _ := c.Value(authKey).(*types.JWTClaims)
userId, _ := strconv.ParseInt(authUser.Subject, 10, 64)
return userId, authUser.TgSession
return userId
}
func GetJWTUser(c context.Context) *types.JWTClaims {

View file

@ -2,6 +2,10 @@ package cache
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
@ -29,8 +33,16 @@ func NewCache(ctx context.Context, conf *config.CacheConfig) Cacher {
cacher = NewMemoryCache(conf.MaxSize)
} else {
cacher = NewRedisCache(ctx, redis.NewClient(&redis.Options{
Addr: conf.RedisAddr,
Password: conf.RedisPass,
Addr: conf.RedisAddr,
Password: conf.RedisPass,
DialTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
PoolSize: 10,
MinIdleConns: 5,
MaxIdleConns: 10,
ConnMaxIdleTime: 5 * time.Minute,
ConnMaxLifetime: 1 * time.Hour,
}))
}
return cacher
@ -119,3 +131,61 @@ func (r *RedisCache) Delete(keys ...string) error {
}
return r.client.Del(r.ctx, keys...).Err()
}
func Fetch[T any](cache Cacher, key string, expiration time.Duration, fn func() (T, error)) (T, error) {
var zero, value T
err := cache.Get(key, &value)
if err != nil {
if errors.Is(err, freecache.ErrNotFound) || errors.Is(err, redis.Nil) {
value, err = fn()
if err != nil {
return zero, err
}
cache.Set(key, &value, expiration)
return value, nil
}
return zero, err
}
return value, nil
}
func Key(args ...interface{}) string {
parts := make([]string, len(args))
for i, arg := range args {
parts[i] = formatValue(arg)
}
return strings.Join(parts, ":")
}
func formatValue(v interface{}) string {
if v == nil {
return "nil"
}
val := reflect.ValueOf(v)
switch val.Kind() {
case reflect.Ptr:
if val.IsNil() {
return "nil"
}
return formatValue(val.Elem().Interface())
case reflect.Array, reflect.Slice:
parts := make([]string, val.Len())
for i := 0; i < val.Len(); i++ {
parts[i] = formatValue(val.Index(i).Interface())
}
return fmt.Sprintf("[%s]", strings.Join(parts, ","))
case reflect.Map:
parts := make([]string, 0, val.Len())
for _, key := range val.MapKeys() {
k := formatValue(key.Interface())
v := formatValue(val.MapIndex(key).Interface())
parts = append(parts, fmt.Sprintf("%s=%s", k, v))
}
return fmt.Sprintf("{%s}", strings.Join(parts, ","))
case reflect.Struct:
return fmt.Sprintf("%+v", v)
default:
return fmt.Sprintf("%v", v)
}
}

View file

@ -5,16 +5,16 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/pkg/models"
)
func TestCache(t *testing.T) {
var value = api.File{
var value = models.File{
Name: "file.jpeg",
Type: "file",
}
var result api.File
var result models.File
cache := NewMemoryCache(1 * 1024 * 1024)
@ -25,3 +25,149 @@ func TestCache(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, result, value)
}
func TestKey(t *testing.T) {
tests := []struct {
name string
args []interface{}
expected string
}{
{
name: "simple strings",
args: []interface{}{"user", "123"},
expected: "user:123",
},
{
name: "mixed types",
args: []interface{}{"cache", 123, true},
expected: "cache:123:true",
},
{
name: "with nil",
args: []interface{}{"key", nil, "value"},
expected: "key:nil:value",
},
{
name: "empty args",
args: []interface{}{},
expected: "",
},
{
name: "single arg",
args: []interface{}{"solo"},
expected: "solo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Key(tt.args...)
if result != tt.expected {
t.Errorf("Key() = %v, want %v", result, tt.expected)
}
})
}
}
func TestFormatValue(t *testing.T) {
type testStruct struct {
Name string
Age int
}
tests := []struct {
name string
input interface{}
expected string
}{
{
name: "nil value",
input: nil,
expected: "nil",
},
{
name: "string",
input: "test",
expected: "test",
},
{
name: "integer",
input: 123,
expected: "123",
},
{
name: "boolean",
input: true,
expected: "true",
},
{
name: "slice of strings",
input: []string{"a", "b", "c"},
expected: "[a,b,c]",
},
{
name: "slice of ints",
input: []int{1, 2, 3},
expected: "[1,2,3]",
},
{
name: "empty slice",
input: []string{},
expected: "[]",
},
{
name: "map string to string",
input: map[string]string{"a": "1", "b": "2"},
expected: "{a=1,b=2}",
},
{
name: "empty map",
input: map[string]string{},
expected: "{}",
},
{
name: "struct",
input: testStruct{Name: "John", Age: 30},
expected: "{Name:John Age:30}",
},
{
name: "pointer to string",
input: func() interface{} { s := "test"; return &s }(),
expected: "test",
},
{
name: "nil pointer",
input: func() interface{} { var s *string; return s }(),
expected: "nil",
},
{
name: "nested slice",
input: [][]int{{1, 2}, {3, 4}},
expected: "[[1,2],[3,4]]",
},
{
name: "complex mixed structure",
input: struct {
ID int
Tags []string
Meta map[string]interface{}
Valid bool
}{
ID: 1,
Tags: []string{"a", "b"},
Meta: map[string]interface{}{"count": 42},
Valid: true,
},
expected: "{ID:1 Tags:[a b] Meta:map[count:42] Valid:true}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatValue(tt.input)
if result != tt.expected {
t.Errorf("formatValue() = %v, want %v", result, tt.expected)
}
})
}
}

View file

@ -2,13 +2,13 @@ package config
import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/mitchellh/go-homedir"
"github.com/mitchellh/mapstructure"
"github.com/go-viper/mapstructure/v2"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
@ -144,7 +144,7 @@ func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error {
if cfgFile != "" {
cl.v.SetConfigFile(cfgFile)
} else {
home, err := homedir.Dir()
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("error getting home directory: %v", err)
}

View file

@ -2,14 +2,13 @@ package reader
import (
"context"
"fmt"
"io"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/crypt"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
)
@ -20,7 +19,7 @@ type Range struct {
type LinearReader struct {
ctx context.Context
file *api.File
file *models.File
parts []types.Part
ranges []Range
pos int
@ -52,7 +51,7 @@ func calculatePartByteRanges(start, end, partSize int64) []Range {
func NewLinearReader(ctx context.Context,
client *tg.Client,
cache cache.Cacher,
file *api.File,
file *models.File,
parts []types.Part,
start,
end int64,
@ -61,7 +60,7 @@ func NewLinearReader(ctx context.Context,
) (io.ReadCloser, error) {
size := parts[0].Size
if file.Encrypted.Value {
if file.Encrypted {
size = parts[0].DecryptedSize
}
r := &LinearReader{
@ -129,22 +128,22 @@ func (r *LinearReader) moveToNextPart() error {
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
currentRange := r.ranges[r.pos]
partID := r.parts[currentRange.PartNo].ID
partId := r.parts[currentRange.PartNo].ID
chunkSrc := &chunkSource{
channelID: r.file.ChannelId.Value,
partID: partID,
channelId: *r.file.ChannelId,
partId: partId,
client: r.client,
concurrency: r.concurrency,
cache: r.cache,
key: fmt.Sprintf("files:location:%s:%d", r.file.ID.Value, partID),
key: cache.Key("files", "location", r.file.ID, partId),
}
var (
reader io.ReadCloser
err error
)
if r.file.Encrypted.Value {
if r.file.Encrypted {
salt := r.parts[r.ranges[r.pos].PartNo].Salt
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
reader, err = cipher.DecryptDataSeek(r.ctx,

View file

@ -25,8 +25,8 @@ type ChunkSource interface {
}
type chunkSource struct {
channelID int64
partID int64
channelId int64
partId int64
concurrency int
client *tg.Client
key string
@ -46,7 +46,7 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
err = c.cache.Get(c.key, location)
if err != nil {
location, err = tgc.GetLocation(ctx, c.client, c.channelID, c.partID)
location, err = tgc.GetLocation(ctx, c.client, c.channelId, c.partId)
if err != nil {
return nil, err
}

View file

@ -5,14 +5,13 @@ import (
"path/filepath"
"time"
"github.com/mitchellh/go-homedir"
"github.com/tgdrive/teldrive/internal/utils"
"go.etcd.io/bbolt"
)
func NewBoltDB(sessionFile string) (*bbolt.DB, error) {
if sessionFile == "" {
dir, err := homedir.Dir()
dir, err := os.UserHomeDir()
if err != nil {
dir = utils.ExecutableDir()
} else {

View file

@ -12,13 +12,14 @@ import (
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/utils"
"github.com/tgdrive/teldrive/pkg/types"
"go.etcd.io/bbolt"
"golang.org/x/sync/errgroup"
)
var (
ErrInValidChannelID = errors.New("invalid channel id")
ErrInValidChannelId = errors.New("invalid channel id")
ErrInvalidChannelMessages = errors.New("invalid channel messages")
)
@ -33,7 +34,7 @@ func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*t
}
if len(channels.GetChats()) == 0 {
return nil, ErrInValidChannelID
return nil, ErrInValidChannelId
}
return channels.GetChats()[0].(*tg.Channel).AsInput(), nil
}
@ -71,15 +72,11 @@ func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int6
func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {
msgIds := []tg.InputMessageClass{}
for _, id := range ids {
msgIds = append(msgIds, &tg.InputMessageID{ID: id})
}
messageRequest := tg.ChannelsGetMessagesRequest{
Channel: channel,
ID: msgIds,
ID: utils.Map(ids, func(id int) tg.InputMessageClass {
return &tg.InputMessageID{ID: id}
}),
}
res, err := client.ChannelsGetMessages(ctx, &messageRequest)

View file

@ -17,21 +17,21 @@ func NewBotWorker() *BotWorker {
}
}
func (w *BotWorker) Set(bots []string, channelID int64) {
func (w *BotWorker) Set(bots []string, channelId int64) {
w.mu.Lock()
defer w.mu.Unlock()
if _, ok := w.bots[channelID]; ok {
if _, ok := w.bots[channelId]; ok {
return
}
w.bots[channelID] = bots
w.currIdx[channelID] = 0
w.bots[channelId] = bots
w.currIdx[channelId] = 0
}
func (w *BotWorker) Next(channelID int64) (string, int) {
func (w *BotWorker) Next(channelId int64) (string, int) {
w.mu.RLock()
defer w.mu.RUnlock()
bots := w.bots[channelID]
index := w.currIdx[channelID]
w.currIdx[channelID] = (index + 1) % len(bots)
bots := w.bots[channelId]
index := w.currIdx[channelId]
w.currIdx[channelId] = (index + 1) % len(bots)
return bots[index], index
}

View file

@ -5,73 +5,18 @@ import (
"path/filepath"
"regexp"
"strings"
"time"
"reflect"
"unicode"
)
func CamelToPascalCase(input string) string {
var result strings.Builder
upperNext := true
for _, char := range input {
if unicode.IsLetter(char) || unicode.IsDigit(char) {
if upperNext {
result.WriteRune(unicode.ToUpper(char))
upperNext = false
} else {
result.WriteRune(char)
}
} else {
upperNext = true
}
}
return result.String()
}
func CamelToSnake(input string) string {
re := regexp.MustCompile("([a-z0-9])([A-Z])")
snake := re.ReplaceAllString(input, "${1}_${2}")
return strings.ToLower(snake)
}
func GetField(v interface{}, field string) string {
r := reflect.ValueOf(v)
f := reflect.Indirect(r).FieldByName(field)
fieldValue := f.Interface()
switch v := fieldValue.(type) {
case string:
return v
case time.Time:
return v.Format(time.RFC3339)
default:
return ""
}
}
func Ptr[T any](t T) *T {
return &t
}
func BoolPointer(b bool) *bool {
return &b
}
func IntPointer(b int) *int {
return &b
}
func Int64Pointer(b int64) *int64 {
return &b
}
func StringPointer(b string) *string {
return &b
}
func PathExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
@ -84,8 +29,34 @@ func PathExists(path string) (bool, error) {
}
func ExecutableDir() string {
path, _ := os.Executable()
return filepath.Dir(path)
}
func Filter[T any](slice []T, predicate func(T) bool) []T {
var result []T
for _, v := range slice {
if predicate(v) {
result = append(result, v)
}
}
return result
}
func Map[T any, U any](slice []T, mapper func(T) U) []U {
var result []U
for _, v := range slice {
result = append(result, mapper(v))
}
return result
}
func Find[T any](slice []T, predicate func(T) bool) (T, bool) {
for _, v := range slice {
if predicate(v) {
return v, true
}
}
var zero T
return zero, false
}

View file

@ -2,47 +2,41 @@ package mapper
import (
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/utils"
"github.com/tgdrive/teldrive/pkg/models"
)
func ToFileOut(file models.File, extended bool) *api.File {
func ToFileOut(file models.File) *api.File {
res := &api.File{
ID: api.NewOptString(file.Id),
ID: api.NewOptString(file.ID),
Name: file.Name,
Type: api.FileType(file.Type),
MimeType: api.NewOptString(file.MimeType),
Encrypted: api.NewOptBool(file.Encrypted),
ParentId: api.NewOptString(file.ParentID.String),
UpdatedAt: api.NewOptDateTime(file.UpdatedAt),
}
if file.ParentId != nil {
res.ParentId = api.NewOptString(*file.ParentId)
}
if file.Size != nil {
res.Size = api.NewOptInt64(*file.Size)
}
if file.Category != "" {
res.Category = api.NewOptFileCategory(api.FileCategory(file.Category))
}
if extended {
res.Parts = file.Parts
if file.ChannelID != nil {
res.ChannelId = api.NewOptInt64(*file.ChannelID)
}
}
return res
}
func ToUploadOut(parts []models.Upload) []api.UploadPart {
res := []api.UploadPart{}
for _, part := range parts {
res = append(res, api.UploadPart{
return utils.Map(parts, func(part models.Upload) api.UploadPart {
return api.UploadPart{
Name: part.Name,
PartId: part.PartId,
ChannelId: part.ChannelID,
ChannelId: part.ChannelId,
PartNo: part.PartNo,
Size: part.Size,
Encrypted: part.Encrypted,
Salt: api.NewOptString(part.Salt),
})
}
return res
}
})
}

View file

@ -2,8 +2,8 @@ package models
type Bot struct {
Token string `gorm:"type:text;primaryKey"`
UserID int64 `gorm:"type:bigint"`
BotID int64 `gorm:"type:bigint"`
UserId int64 `gorm:"type:bigint"`
BotId int64 `gorm:"type:bigint"`
BotUserName string `gorm:"type:text"`
ChannelID int64 `gorm:"type:bigint"`
ChannelId int64 `gorm:"type:bigint"`
}

View file

@ -1,8 +1,8 @@
package models
type Channel struct {
ChannelID int64 `gorm:"type:bigint;primaryKey"`
ChannelId int64 `gorm:"type:bigint;primaryKey"`
ChannelName string `gorm:"type:text"`
UserID int64 `gorm:"type:bigint;"`
UserId int64 `gorm:"type:bigint;"`
Selected bool `gorm:"type:boolean;"`
}

View file

@ -1,7 +1,6 @@
package models
import (
"database/sql"
"time"
"github.com/tgdrive/teldrive/internal/api"
@ -9,18 +8,18 @@ import (
)
type File struct {
Id string `gorm:"type:uuid;primaryKey;default:uuid7()"`
ID string `gorm:"type:uuid;primaryKey;default:uuid7()"`
Name string `gorm:"type:text;not null"`
Type string `gorm:"type:text;not null"`
MimeType string `gorm:"type:text;not null"`
Size *int64 `gorm:"type:bigint"`
Category string `gorm:"type:text"`
Encrypted bool `gorm:"default:false"`
UserID int64 `gorm:"type:bigint;not null"`
UserId int64 `gorm:"type:bigint;not null"`
Status string `gorm:"type:text"`
ParentID sql.NullString `gorm:"type:uuid;index"`
ParentId *string `gorm:"type:uuid;index"`
Parts datatypes.JSONSlice[api.Part] `gorm:"type:jsonb"`
ChannelID *int64 `gorm:"type:bigint"`
ChannelId *int64 `gorm:"type:bigint"`
CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"`
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
}

View file

@ -6,10 +6,10 @@ import (
type FileShare struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4();primary_key"`
FileID string `gorm:"type:uuid;not null"`
FileId string `gorm:"type:uuid;not null"`
Password *string `gorm:"type:text"`
ExpiresAt *time.Time `gorm:"type:timestamp"`
CreatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"`
UpdatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"`
UserID int64 `gorm:"type:bigint;not null"`
UserId int64 `gorm:"type:bigint;not null"`
}

View file

@ -12,7 +12,7 @@ type Upload struct {
PartId int `gorm:"type:integer"`
Encrypted bool `gorm:"default:false"`
Salt string `gorm:"type:text"`
ChannelID int64 `gorm:"type:bigint"`
ChannelId int64 `gorm:"type:bigint"`
Size int64 `gorm:"type:bigint"`
CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"`
}

View file

@ -25,6 +25,7 @@ import (
"github.com/gotd/td/tgerr"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/pkg/models"
@ -80,7 +81,7 @@ func (a *apiService) AuthLogin(ctx context.Context, session *api.SessionCreate)
Name: "root",
Type: "folder",
MimeType: "drive/folder",
UserID: session.UserId,
UserId: session.UserId,
Status: "active",
Parts: nil,
}
@ -129,7 +130,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
return err
})
a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{})
a.cache.Delete(fmt.Sprintf("sessions:%s", authUser.Hash))
a.cache.Delete(cache.Key("sessions", authUser.Hash))
return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil
}
@ -360,7 +361,7 @@ func pack32BinaryIP4(ip4Address string) []byte {
return buf.Bytes()
}
func generateTgSession(dcID int, authKey []byte, port int) string {
func generateTgSession(dcId int, authKey []byte, port int) string {
dcMaps := map[int]string{
1: "149.154.175.53",
@ -370,8 +371,8 @@ func generateTgSession(dcID int, authKey []byte, port int) string {
5: "91.108.56.130",
}
dcIDByte := byte(dcID)
serverAddressBytes := pack32BinaryIP4(dcMaps[dcID])
dcIDByte := byte(dcId)
serverAddressBytes := pack32BinaryIP4(dcMaps[dcId])
portByte := make([]byte, 2)
binary.BigEndian.PutUint16(portByte, uint16(port))

View file

@ -2,106 +2,88 @@ package services
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-faster/errors"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/crypt"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/internal/utils"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
"go.uber.org/zap"
"gorm.io/gorm"
)
func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *api.File) ([]types.Part, error) {
func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) {
return cache.Fetch(c, cache.Key("files", "messages", file.ID), 60*time.Minute, func() ([]types.Part, error) {
messages, err := tgc.GetMessages(ctx, client.API(), utils.Map(file.Parts, func(part api.Part) int {
return part.ID
}), *file.ChannelId)
parts := []types.Part{}
key := fmt.Sprintf("files:messages:%s", file.ID.Value)
err := cache.Get(key, &parts)
if err == nil {
if err != nil {
return nil, err
}
parts := []types.Part{}
for i, message := range messages {
switch item := message.(type) {
case *tg.Message:
media, ok := item.Media.(*tg.MessageMediaDocument)
if !ok {
continue
}
document, ok := media.Document.(*tg.Document)
if !ok {
continue
}
part := types.Part{
ID: int64(file.Parts[i].ID),
Size: document.Size,
Salt: file.Parts[i].Salt.Value,
}
if file.Encrypted {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
}
parts = append(parts, part)
}
}
if len(parts) != len(file.Parts) {
msg := "file parts mismatch"
logging.FromContext(ctx).Error(msg, zap.String("name", file.Name),
zap.Int("expected", len(file.Parts)), zap.Int("actual", len(parts)))
return nil, errors.New(msg)
}
return parts, nil
}
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value)
if err != nil {
return nil, err
}
for i, message := range messages {
item := message.(*tg.Message)
media := item.Media.(*tg.MessageMediaDocument)
document := media.Document.(*tg.Document)
part := types.Part{
ID: int64(file.Parts[i].ID),
Size: document.Size,
Salt: file.Parts[i].Salt.Value,
}
if file.Encrypted.Value {
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
}
parts = append(parts, part)
}
cache.Set(key, &parts, 60*time.Minute)
return parts, nil
})
}
func getDefaultChannel(db *gorm.DB, cache cache.Cacher, userID int64) (int64, error) {
var channelId int64
key := fmt.Sprintf("users:channel:%d", userID)
err := cache.Get(key, &channelId)
if err == nil {
return channelId, nil
}
var channelIds []int64
db.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
Pluck("channel_id", &channelIds)
if len(channelIds) == 1 {
channelId = channelIds[0]
cache.Set(key, channelId, 0)
}
if channelId == 0 {
return channelId, errors.New("default channel not set")
}
return channelId, nil
func getDefaultChannel(db *gorm.DB, c cache.Cacher, userId int64) (int64, error) {
return cache.Fetch(c, cache.Key("users", "channel", userId), 0, func() (int64, error) {
var channelIds []int64
if err := db.Model(&models.Channel{}).Where("user_id = ?", userId).Where("selected = ?", true).
Pluck("channel_id", &channelIds).Error; err != nil {
return 0, err
}
if len(channelIds) == 0 {
return 0, fmt.Errorf("no default channel found for user %d", userId)
}
return channelIds[0], nil
})
}
func getBotsToken(db *gorm.DB, cache cache.Cacher, userID, channelId int64) ([]string, error) {
var bots []string
key := fmt.Sprintf("users:bots:%d:%d", userID, channelId)
err := cache.Get(key, &bots)
if err == nil {
func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]string, error) {
return cache.Fetch(c, cache.Key("users", "bots", userId, channelId), 0, func() ([]string, error) {
var bots []string
if err := db.Model(&models.Bot{}).Where("user_id = ?", userId).
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
return nil, err
}
return bots, nil
}
if err := db.Model(&models.Bot{}).Where("user_id = ?", userID).
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
return nil, err
}
cache.Set(key, &bots, 0)
return bots, nil
})
}

View file

@ -3,7 +3,6 @@ package services
import (
"context"
"crypto/rand"
"database/sql"
"encoding/binary"
"errors"
"fmt"
@ -20,6 +19,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/category"
"github.com/tgdrive/teldrive/internal/database"
"github.com/tgdrive/teldrive/internal/http_range"
@ -71,7 +71,7 @@ func randInt64() (int64, error) {
b := &buffer{Buf: buf[:]}
return b.long()
}
func isUUID(str string) bool {
func isUUId(str string) bool {
_, err := uuid.Parse(str)
return err == nil
}
@ -97,10 +97,10 @@ func (a *apiService) getFileFromPath(path string, userId int64) (*models.File, e
}
func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStats, error) {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var stats []api.CategoryStats
if err := a.db.Model(&models.File{}).Select("category", "COUNT(*) as total_files", "coalesce(SUM(size),0) as total_size").
Where(&models.File{UserID: userId, Type: "file", Status: "active"}).
Where(&models.File{UserId: userId, Type: "file", Status: "active"}).
Order("category ASC").Group("category").Find(&stats).Error; err != nil {
return nil, &apiError{err: err}
}
@ -110,9 +110,9 @@ func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStat
func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params api.FilesCopyParams) (*api.File, error) {
userId, session := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
var res []models.File
@ -133,12 +133,9 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
}
err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
ids := []int{}
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
}
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID)
ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelId)
if err != nil {
return err
@ -198,13 +195,13 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
}
var parentId string
if !isUUID(req.Destination) {
if !isUUId(req.Destination) {
var destRes []models.File
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
Scan(&destRes).Error; err != nil {
return nil, &apiError{err: err}
}
parentId = destRes[0].Id
parentId = destRes[0].ID
} else {
parentId = req.Destination
}
@ -218,13 +215,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
if len(newIds) > 0 {
dbFile.Parts = datatypes.NewJSONSlice(newIds)
}
dbFile.UserID = userId
dbFile.UserId = userId
dbFile.Status = "active"
dbFile.ParentID = sql.NullString{
String: parentId,
Valid: true,
}
dbFile.ChannelID = &channelId
dbFile.ParentId = utils.Ptr(parentId)
dbFile.ChannelId = &channelId
dbFile.Encrypted = file.Encrypted
dbFile.Category = string(file.Category)
if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
@ -237,11 +231,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
return nil, &apiError{err: err}
}
return mapper.ToFileOut(dbFile, false), nil
return mapper.ToFileOut(dbFile), nil
}
func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.File, error) {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var (
fileDB models.File
@ -267,15 +261,9 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
if err != nil {
return nil, &apiError{err: err, code: 404}
}
fileDB.ParentID = sql.NullString{
String: parent.Id,
Valid: true,
}
fileDB.ParentId = utils.Ptr(parent.ID)
} else if fileIn.ParentId.Value != "" {
fileDB.ParentID = sql.NullString{
String: fileIn.ParentId.Value,
Valid: true,
}
fileDB.ParentId = utils.Ptr(fileIn.ParentId.Value)
}
@ -291,25 +279,17 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
} else {
channelId = fileIn.ChannelId.Value
}
fileDB.ChannelID = &channelId
fileDB.ChannelId = &channelId
fileDB.MimeType = fileIn.MimeType.Value
fileDB.Category = string(category.GetCategory(fileIn.Name))
if len(fileIn.Parts) > 0 {
parts := []api.Part{}
for _, part := range fileIn.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
fileDB.Parts = datatypes.NewJSONSlice(parts)
fileDB.Parts = datatypes.NewJSONSlice(mapParts(fileIn.Parts))
}
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
}
fileDB.Name = fileIn.Name
fileDB.Type = string(fileIn.Type)
fileDB.UserID = userId
fileDB.UserId = userId
fileDB.Status = "active"
fileDB.Encrypted = fileIn.Encrypted.Value
if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() {
@ -323,11 +303,11 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
}
return nil, &apiError{err: err}
}
return mapper.ToFileOut(fileDB, false), nil
return mapper.ToFileOut(fileDB), nil
}
func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCreate, params api.FilesCreateShareParams) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var fileShare models.FileShare
@ -339,11 +319,11 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
fileShare.Password = utils.Ptr(string(bytes))
}
fileShare.FileID = params.ID
fileShare.FileId = params.ID
if req.ExpiresAt.IsSet() {
fileShare.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
}
fileShare.UserID = userId
fileShare.UserId = userId
if err := a.db.Create(&fileShare).Error; err != nil {
return &apiError{err: err}
@ -353,7 +333,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
}
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
return &apiError{err: err}
}
@ -362,7 +342,7 @@ func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error
}
func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDeleteShareParams) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var deletedShare models.FileShare
@ -371,14 +351,14 @@ func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDelet
return &apiError{err: err}
}
if deletedShare.ID != "" {
a.cache.Delete(fmt.Sprintf("shares:%s", deletedShare.ID))
a.cache.Delete(cache.Key("shared", deletedShare.ID))
}
return nil
}
func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreate, params api.FilesEditShareParams) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var fileShareUpdate models.FileShare
@ -387,7 +367,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
if err != nil {
return &apiError{err: err}
}
fileShareUpdate.Password = utils.StringPointer(string(bytes))
fileShareUpdate.Password = utils.Ptr(string(bytes))
}
if req.ExpiresAt.IsSet() {
fileShareUpdate.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
@ -403,26 +383,25 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
func (a *apiService) FilesGetById(ctx context.Context, params api.FilesGetByIdParams) (*api.File, error) {
var result []fullFileDB
notFoundResponse := &apiError{err: errors.New("file not found"), code: 404}
if err := a.db.Model(&models.File{}).Select("*",
"(select get_path_from_file_id as path from teldrive.get_path_from_file_id(id))").
Where("id = ?", params.ID).Scan(&result).Error; err != nil {
if database.IsRecordNotFoundErr(err) {
return nil, notFoundResponse
}
return nil, &apiError{err: err}
}
if len(result) == 0 {
return nil, notFoundResponse
return nil, &apiError{err: errors.New("file not found"), code: 404}
}
res := mapper.ToFileOut(result[0].File, true)
res := mapper.ToFileOut(result[0].File)
res.Path = api.NewOptString(result[0].Path)
if result[0].ChannelId != nil {
res.ChannelId = api.NewOptInt64(*result[0].ChannelId)
}
return res, nil
}
func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams) (*api.FileList, error) {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
queryBuilder := &fileQueryBuilder{db: a.db}
@ -430,7 +409,7 @@ func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams)
}
func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
if err := a.db.Exec("select * from teldrive.create_directories(?, ?)", userId, req.Path).Error; err != nil {
return &apiError{err: err}
@ -439,18 +418,18 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
}
func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
items := pgtype.Array[string]{
Elements: req.Ids,
Valid: true,
Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}},
}
if !isUUID(req.Destination) {
if !isUUId(req.Destination) {
r, err := a.getFileFromPath(req.Destination, userId)
if err != nil {
return &apiError{err: err}
}
req.Destination = r.Id
req.Destination = r.ID
}
if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId).
Update("parent_id", req.Destination).Error; err != nil {
@ -462,7 +441,7 @@ func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
}
func (a *apiService) FilesShareByid(ctx context.Context, params api.FilesShareByidParams) (*api.FileShare, error) {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var result []models.FileShare
notFoundErr := &apiError{err: errors.New("invalid share"), code: 404}
@ -500,15 +479,7 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
updateDb.Name = req.Name.Value
}
if len(req.Parts) > 0 {
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updateDb.Parts = datatypes.NewJSONSlice(parts)
updateDb.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
}
if req.Size.Value != 0 {
updateDb.Size = utils.Ptr(req.Size.Value)
@ -523,17 +494,18 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
return nil, &apiError{err: err}
}
a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
a.cache.Delete(cache.Key("files", params.ID))
file := models.File{}
if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil {
return nil, &apiError{err: err}
}
return mapper.ToFileOut(file, false), nil
return mapper.ToFileOut(file), nil
}
func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var file models.File
@ -545,29 +517,18 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
if err != nil {
return &apiError{err: err}
}
updatePayload.ChannelID = &channelId
updatePayload.ChannelId = &channelId
} else {
updatePayload.ChannelID = &req.ChannelId.Value
updatePayload.ChannelId = &req.ChannelId.Value
}
if len(req.Parts) > 0 {
parts := []api.Part{}
for _, part := range req.Parts {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
parts = append(parts, p)
}
updatePayload.Parts = datatypes.NewJSONSlice(parts)
updatePayload.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
}
if req.Name.Value != "" {
updatePayload.Name = req.Name.Value
}
if req.ParentId.Value != "" {
updatePayload.ParentID = sql.NullString{
String: req.ParentId.Value,
Valid: true,
}
updatePayload.ParentId = utils.Ptr(req.ParentId.Value)
}
updatePayload.UpdatedAt = req.UpdatedAt
@ -592,28 +553,23 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
return &apiError{err: err}
}
if len(file.Parts) > 0 && file.ChannelID != nil {
_, session := auth.GetUser(ctx)
ids := []int{}
keys := []string{cache.Key("files", params.ID)}
if len(file.Parts) > 0 && file.ChannelId != nil {
ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
tgc.DeleteMessages(ctx, client, *file.ChannelId, ids)
keys = append(keys, cache.Key("files", "messages", params.ID))
for _, part := range file.Parts {
ids = append(ids, int(part.ID))
keys = append(keys, cache.Key("files", "location", params.ID, part.ID))
}
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
tgc.DeleteMessages(ctx, client, *file.ChannelID, ids)
keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)}
for _, part := range file.Parts {
keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID))
}
a.cache.Delete(keys...)
}
a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
a.cache.Delete(keys...)
return nil
}
func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileID string, userId int64) {
func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileId string, userId int64) {
ctx := r.Context()
var (
session *models.Session
@ -646,19 +602,17 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
session = &models.Session{UserId: userId}
}
file := &api.File{}
key := fmt.Sprintf("files:%s", fileID)
err = e.api.cache.Get(key, file)
file, err := cache.Fetch(e.api.cache, cache.Key("files", fileId), 0, func() (*models.File, error) {
var result models.File
if err := e.api.db.Model(&result).Where("id = ?", fileId).First(&result).Error; err != nil {
return nil, err
}
return &result, nil
})
if err != nil {
file, err = e.api.FilesGetById(ctx, api.FilesGetByIdParams{ID: fileID})
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
e.api.cache.Set(key, file, 0)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Accept-Ranges", "bytes")
@ -666,16 +620,15 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
var start, end int64
rangeHeader := r.Header.Get("Range")
contentType := defaultContentType
if file.Size.Value == 0 {
w.Header().Set("Content-Type", file.MimeType.Or(defaultContentType))
if file.MimeType != "" {
contentType = file.MimeType
}
if file.Size == nil || *file.Size == 0 {
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", "0")
if rangeHeader != "" {
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
http.Error(w, "Requested Range Not Satisfiable", http.StatusRequestedRangeNotSatisfiable)
return
}
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": file.Name}))
w.WriteHeader(http.StatusOK)
return
@ -684,11 +637,11 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
status := http.StatusOK
if rangeHeader == "" {
start = 0
end = file.Size.Value - 1
end = *file.Size - 1
} else {
ranges, err := http_range.Parse(rangeHeader, file.Size.Value)
ranges, err := http_range.Parse(rangeHeader, *file.Size)
if err == http_range.ErrNoOverlap {
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", *file.Size))
http.Error(w, http_range.ErrNoOverlap.Error(), http.StatusRequestedRangeNotSatisfiable)
return
}
@ -702,20 +655,18 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
}
start = ranges[0].Start
end = ranges[0].End
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size.Value))
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, *file.Size))
status = http.StatusPartialContent
}
contentLength := end - start + 1
mimeType := file.MimeType.Or(defaultContentType)
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileID+strconv.FormatInt(file.Size.Value, 10))))
w.Header().Set("Last-Modified", file.UpdatedAt.Value.UTC().Format(http.TimeFormat))
w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileId+strconv.FormatInt(*file.Size, 10))))
w.Header().Set("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat))
disposition := "inline"
@ -733,7 +684,7 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
return
}
tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, file.ChannelId.Value)
tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, *file.ChannelId)
if err != nil {
http.Error(w, "failed to get bots", http.StatusInternalServerError)
@ -761,9 +712,9 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
multiThreads = 0
} else {
e.api.worker.Set(tokens, file.ChannelId.Value)
e.api.worker.Set(tokens, *file.ChannelId)
token, _ = e.api.worker.Next(file.ChannelId.Value)
token, _ = e.api.worker.Next(*file.ChannelId)
client, err = tgc.BotClient(ctx, e.api.boltdb, &e.api.cnf.TG, token, middlewares...)
if err != nil {
@ -816,5 +767,16 @@ func (e *extendedService) SharesStream(w http.ResponseWriter, r *http.Request, s
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
e.FilesStream(w, r, fileId, share.UserID)
e.FilesStream(w, r, fileId, share.UserId)
}
func mapParts(_parts []api.Part) []api.Part {
return utils.Map(_parts, func(part api.Part) api.Part {
p := api.Part{ID: part.ID}
if part.Salt.Value != "" {
p.Salt = part.Salt
}
return p
})
}

View file

@ -17,8 +17,7 @@ import (
)
type fileQueryBuilder struct {
db *gorm.DB
selectAllFields bool
db *gorm.DB
}
type fileResponse struct {
@ -26,7 +25,7 @@ type fileResponse struct {
Total int
}
var selectedFields = []string{"id", "name", "type", "mime_type", "category", "encrypted", "size", "parent_id", "updated_at"}
var selectedFields = []string{"id", "name", "type", "mime_type", "category", "channel_id", "encrypted", "size", "parent_id", "updated_at"}
func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int64) (*api.FileList, error) {
query := afb.db.Where("user_id = ?", userId).Where("status = ?", filesQuery.Status.Value)
@ -50,11 +49,7 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int
count = res[0].Total
}
files := []api.File{}
for _, file := range res {
files = append(files, *mapper.ToFileOut(file.File, afb.selectAllFields))
}
files := utils.Map(res, func(item fileResponse) api.File { return *mapper.ToFileOut(item.File) })
return &api.FileList{Items: files,
Meta: api.FileListMeta{Count: count,
@ -192,15 +187,10 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File
orderField := utils.CamelToSnake(string(filesQuery.Sort.Value))
op := getOrderOperation(filesQuery)
fields := selectedFields
if afb.selectAllFields {
fields = append(fields, "parts", "channel_id")
}
return afb.buildSubqueryCTE(query, filesQuery, userId).Clauses(exclause.NewWith("ranked_scores", afb.db.Model(&models.File{}).Select(orderField, "count(*) OVER () as total",
fmt.Sprintf("ROW_NUMBER() OVER (ORDER BY %s %s) AS rank", orderField, strings.ToUpper(string(filesQuery.Order.Value)))).
Where(query))).Model(&models.File{}).
Select(fields, "(select total from ranked_scores limit 1) as total").
Select(selectedFields, "(select total from ranked_scores limit 1) as total").
Where(fmt.Sprintf("%s %s (SELECT %s FROM ranked_scores WHERE rank = ?)", orderField, op, orderField),
max((filesQuery.Page.Value-1)*filesQuery.Limit.Value, 1)).
Where(query).Order(getOrder(filesQuery)).Limit(filesQuery.Limit.Value)
@ -223,13 +213,6 @@ func getOrder(filesQuery *api.FilesListParams) string {
return fmt.Sprintf("%s %s", orderField, strings.ToUpper(string(filesQuery.Order.Value)))
}
func max(x int, y int) int {
if x > y {
return x
}
return y
}
func getOrderOperation(filesQuery *api.FilesListParams) string {
if filesQuery.Page.Value == 1 {
if filesQuery.Order.Value == api.FileQueryOrderAsc {

View file

@ -60,7 +60,7 @@ func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetById
}
res := &api.FileShareInfo{
Protected: share.Password != nil,
UserId: share.UserID,
UserId: share.UserId,
Type: share.Type,
Name: share.Name,
}
@ -104,16 +104,16 @@ func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListF
Status: api.NewOptFileQueryStatus(api.FileQueryStatusActive),
Order: api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))),
Sort: api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))),
Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserID)
Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserId)
} else {
var file models.File
if err := a.db.Where("id = ?", share.FileID).First(&file).Error; err != nil {
if err := a.db.Where("id = ?", share.FileId).First(&file).Error; err != nil {
if database.IsRecordNotFoundErr(err) {
return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound}
}
return nil, &apiError{err: err}
}
return &api.FileList{Items: []api.File{*mapper.ToFileOut(file, false)},
return &api.FileList{Items: []api.File{*mapper.ToFileOut(file)},
Meta: api.FileListMeta{Count: 1, TotalPages: 1, CurrentPage: 1}}, nil
}

View file

@ -48,7 +48,7 @@ func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPar
}
func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var stats []api.UploadStats
err := a.db.Raw(`
SELECT
@ -94,7 +94,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
}
userId, session := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
fileStream := req.Content.Data
@ -116,7 +116,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
}
if len(tokens) == 0 {
client, err = tgc.AuthClient(ctx, &a.cnf.TG, session)
client, err = tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession)
if err != nil {
return nil, err
}
@ -220,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
Name: params.PartName,
UploadId: params.ID,
PartId: message.ID,
ChannelID: channelId,
ChannelId: channelId,
Size: fileSize,
PartNo: int(params.PartNo),
UserId: userId,
@ -244,7 +244,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
out = api.UploadPart{
Name: partUpload.Name,
PartId: partUpload.PartId,
ChannelId: partUpload.ChannelID,
ChannelId: partUpload.ChannelId,
PartNo: partUpload.PartNo,
Size: partUpload.Size,
Encrypted: partUpload.Encrypted,

View file

@ -16,6 +16,7 @@ import (
"github.com/gotd/td/tgerr"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
@ -27,8 +28,8 @@ import (
)
func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {
userId, session := auth.GetUser(ctx)
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
userId := auth.GetUser(ctx)
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
if len(req.Bots) > 0 {
channelId, err := getDefaultChannel(a.db, a.cache, userId)
@ -47,11 +48,11 @@ func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {
func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, error) {
userID, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
channels := make(map[int64]*api.Channel)
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userID)))
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
iter, err := peerStorage.Iterate(ctx)
if err != nil {
@ -67,6 +68,7 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
}
}
res := []api.Channel{}
for _, channel := range channels {
res = append(res, *channel)
@ -79,10 +81,10 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
}
func (a *apiService) UsersSyncChannels(ctx context.Context) error {
userId, session := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
collector := storage.CollectPeers(peerStorage)
client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
if err != nil {
return &apiError{err: err}
}
@ -96,7 +98,9 @@ func (a *apiService) UsersSyncChannels(ctx context.Context) error {
}
func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, error) {
userId, userSession := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
userSession := auth.GetJWTUser(ctx).TgSession
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...)
@ -150,9 +154,8 @@ func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession,
}
func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProfileImageParams) (*api.UsersProfileImageOKHeaders, error) {
_, session := auth.GetUser(ctx)
client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
if err != nil {
return nil, &apiError{err: err}
@ -194,25 +197,25 @@ func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProf
}
func (a *apiService) UsersRemoveBots(ctx context.Context) error {
userID, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
channelId, err := getDefaultChannel(a.db, a.cache, userID)
channelId, err := getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return &apiError{err: err}
}
if err := a.db.Where("user_id = ?", userID).Where("channel_id = ?", channelId).
if err := a.db.Where("user_id = ?", userId).Where("channel_id = ?", channelId).
Delete(&models.Bot{}).Error; err != nil {
return &apiError{err: err}
}
a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId))
a.cache.Delete(cache.Key("users", "bots", userId, channelId))
return nil
}
func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRemoveSessionParams) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
session := &models.Session{}
@ -236,15 +239,18 @@ func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRem
}
func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
userID, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
var (
channelId int64
err error
)
channelId, _ = getDefaultChannel(a.db, a.cache, userID)
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, &apiError{err: err}
}
tokens, err := getBotsToken(a.db, a.cache, userID, channelId)
tokens, err := getBotsToken(a.db, a.cache, userId, channelId)
if err != nil {
return nil, &apiError{err: err}
@ -253,12 +259,12 @@ func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
}
func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpdate) error {
userId, _ := auth.GetUser(ctx)
userId := auth.GetUser(ctx)
channel := &models.Channel{UserID: userId, Selected: true}
channel := &models.Channel{UserId: userId, Selected: true}
if req.ChannelId.Value != 0 {
channel.ChannelID = req.ChannelId.Value
channel.ChannelId = req.ChannelId.Value
}
if req.ChannelName.Value != "" {
channel.ChannelName = req.ChannelName.Value
@ -270,11 +276,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd
}).Create(channel).Error; err != nil {
return &apiError{err: errors.New("failed to update channel")}
}
a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelID).
a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelId).
Where("user_id = ?", userId).Update("selected", false)
key := fmt.Sprintf("users:channel:%d", userId)
a.cache.Set(key, channel.ChannelID, 0)
a.cache.Set(cache.Key("users", "channel", userId), channel.ChannelId, 0)
return nil
}
@ -359,12 +364,12 @@ func (a *apiService) addBots(c context.Context, client *telegram.Client, userId
payload := []models.Bot{}
for _, info := range botInfoMap {
payload = append(payload, models.Bot{UserID: userId, Token: info.Token, BotID: info.Id,
BotUserName: info.UserName, ChannelID: channelId,
payload = append(payload, models.Bot{UserId: userId, Token: info.Token, BotId: info.Id,
BotUserName: info.UserName, ChannelId: channelId,
})
}
a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId))
a.cache.Delete(cache.Key("users", "bots", userId, channelId))
if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
return err