diff --git a/go.mod b/go.mod index 116c0ae..283605a 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,12 @@ require ( github.com/go-chi/chi/v5 v5.2.0 github.com/go-chi/cors v1.2.1 github.com/go-co-op/gocron v1.37.0 + github.com/go-viper/mapstructure/v2 v2.2.1 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/gotd/contrib v0.21.0 github.com/gotd/td v0.117.0 github.com/iyear/connectproxy v0.1.1 - github.com/mitchellh/go-homedir v1.1.0 - github.com/mitchellh/mapstructure v1.5.0 github.com/ogen-go/ogen v1.8.1 github.com/redis/go-redis/v9 v9.7.0 github.com/spf13/cobra v1.8.1 @@ -50,6 +49,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mfridman/interpolate v0.0.2 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect diff --git a/go.sum b/go.sum index 134b84d..d0e5e48 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= +github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= @@ -148,8 +150,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw= github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo= -github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index fba943b..5416497 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -42,10 +42,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) { } -func GetUser(c context.Context) (int64, string) { +func GetUser(c context.Context) int64 { authUser, _ := c.Value(authKey).(*types.JWTClaims) userId, _ := strconv.ParseInt(authUser.Subject, 10, 64) - return userId, authUser.TgSession + return userId } func GetJWTUser(c context.Context) *types.JWTClaims { diff --git a/internal/cache/cache.go b/internal/cache/cache.go index d5f54a1..f1d9220 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -2,6 +2,10 @@ package cache import ( "context" + "errors" + "fmt" + "reflect" + "strings" "sync" "time" @@ -29,8 +33,16 @@ func NewCache(ctx context.Context, conf *config.CacheConfig) Cacher { cacher = NewMemoryCache(conf.MaxSize) } else { cacher = NewRedisCache(ctx, redis.NewClient(&redis.Options{ - Addr: conf.RedisAddr, - Password: conf.RedisPass, + Addr: conf.RedisAddr, + Password: conf.RedisPass, + DialTimeout: 5 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + PoolSize: 10, + MinIdleConns: 5, + MaxIdleConns: 10, + ConnMaxIdleTime: 5 * time.Minute, + ConnMaxLifetime: 1 * time.Hour, })) } return cacher @@ -119,3 +131,61 @@ func (r *RedisCache) Delete(keys ...string) error { } return r.client.Del(r.ctx, keys...).Err() } + +func Fetch[T any](cache Cacher, key string, expiration time.Duration, fn func() (T, error)) (T, error) { + var zero, value T + err := cache.Get(key, &value) + if err != nil { + if errors.Is(err, freecache.ErrNotFound) || errors.Is(err, redis.Nil) { + value, err = fn() + if err != nil { + return zero, err + } + cache.Set(key, &value, expiration) + return value, nil + } + return zero, err + } + return value, nil +} + +func Key(args ...interface{}) string { + parts := make([]string, len(args)) + for i, arg := range args { + parts[i] = formatValue(arg) + } + return strings.Join(parts, ":") +} + +func formatValue(v interface{}) string { + if v == nil { + return "nil" + } + + val := reflect.ValueOf(v) + switch val.Kind() { + case reflect.Ptr: + if val.IsNil() { + return "nil" + } + return formatValue(val.Elem().Interface()) + case reflect.Array, reflect.Slice: + parts := make([]string, val.Len()) + for i := 0; i < val.Len(); i++ { + parts[i] = formatValue(val.Index(i).Interface()) + } + return fmt.Sprintf("[%s]", strings.Join(parts, ",")) + case reflect.Map: + parts := make([]string, 0, val.Len()) + for _, key := range val.MapKeys() { + k := formatValue(key.Interface()) + v := formatValue(val.MapIndex(key).Interface()) + parts = append(parts, fmt.Sprintf("%s=%s", k, v)) + } + return fmt.Sprintf("{%s}", strings.Join(parts, ",")) + case reflect.Struct: + return fmt.Sprintf("%+v", v) + default: + return fmt.Sprintf("%v", v) + } +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index ddaca5e..4da8dbb 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -5,16 +5,16 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/tgdrive/teldrive/internal/api" + "github.com/tgdrive/teldrive/pkg/models" ) func TestCache(t *testing.T) { - var value = api.File{ + var value = models.File{ Name: "file.jpeg", Type: "file", } - var result api.File + var result models.File cache := NewMemoryCache(1 * 1024 * 1024) @@ -25,3 +25,149 @@ func TestCache(t *testing.T) { assert.NoError(t, err) assert.Equal(t, result, value) } + +func TestKey(t *testing.T) { + tests := []struct { + name string + args []interface{} + expected string + }{ + { + name: "simple strings", + args: []interface{}{"user", "123"}, + expected: "user:123", + }, + { + name: "mixed types", + args: []interface{}{"cache", 123, true}, + expected: "cache:123:true", + }, + { + name: "with nil", + args: []interface{}{"key", nil, "value"}, + expected: "key:nil:value", + }, + { + name: "empty args", + args: []interface{}{}, + expected: "", + }, + { + name: "single arg", + args: []interface{}{"solo"}, + expected: "solo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Key(tt.args...) + if result != tt.expected { + t.Errorf("Key() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestFormatValue(t *testing.T) { + type testStruct struct { + Name string + Age int + } + + tests := []struct { + name string + input interface{} + expected string + }{ + { + name: "nil value", + input: nil, + expected: "nil", + }, + { + name: "string", + input: "test", + expected: "test", + }, + { + name: "integer", + input: 123, + expected: "123", + }, + { + name: "boolean", + input: true, + expected: "true", + }, + { + name: "slice of strings", + input: []string{"a", "b", "c"}, + expected: "[a,b,c]", + }, + { + name: "slice of ints", + input: []int{1, 2, 3}, + expected: "[1,2,3]", + }, + { + name: "empty slice", + input: []string{}, + expected: "[]", + }, + { + name: "map string to string", + input: map[string]string{"a": "1", "b": "2"}, + expected: "{a=1,b=2}", + }, + { + name: "empty map", + input: map[string]string{}, + expected: "{}", + }, + { + name: "struct", + input: testStruct{Name: "John", Age: 30}, + expected: "{Name:John Age:30}", + }, + { + name: "pointer to string", + input: func() interface{} { s := "test"; return &s }(), + expected: "test", + }, + { + name: "nil pointer", + input: func() interface{} { var s *string; return s }(), + expected: "nil", + }, + { + name: "nested slice", + input: [][]int{{1, 2}, {3, 4}}, + expected: "[[1,2],[3,4]]", + }, + { + name: "complex mixed structure", + input: struct { + ID int + Tags []string + Meta map[string]interface{} + Valid bool + }{ + ID: 1, + Tags: []string{"a", "b"}, + Meta: map[string]interface{}{"count": 42}, + Valid: true, + }, + expected: "{ID:1 Tags:[a b] Meta:map[count:42] Valid:true}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatValue(tt.input) + if result != tt.expected { + t.Errorf("formatValue() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index b1c1405..7f59210 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,13 +2,13 @@ package config import ( "fmt" + "os" "path/filepath" "reflect" "strings" "time" - "github.com/mitchellh/go-homedir" - "github.com/mitchellh/mapstructure" + "github.com/go-viper/mapstructure/v2" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -144,7 +144,7 @@ func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error { if cfgFile != "" { cl.v.SetConfigFile(cfgFile) } else { - home, err := homedir.Dir() + home, err := os.UserHomeDir() if err != nil { return fmt.Errorf("error getting home directory: %v", err) } diff --git a/internal/reader/reader.go b/internal/reader/reader.go index 9421e7e..f15b2aa 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -2,14 +2,13 @@ package reader import ( "context" - "fmt" "io" "github.com/gotd/td/tg" - "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/config" "github.com/tgdrive/teldrive/internal/crypt" + "github.com/tgdrive/teldrive/pkg/models" "github.com/tgdrive/teldrive/pkg/types" ) @@ -20,7 +19,7 @@ type Range struct { type LinearReader struct { ctx context.Context - file *api.File + file *models.File parts []types.Part ranges []Range pos int @@ -52,7 +51,7 @@ func calculatePartByteRanges(start, end, partSize int64) []Range { func NewLinearReader(ctx context.Context, client *tg.Client, cache cache.Cacher, - file *api.File, + file *models.File, parts []types.Part, start, end int64, @@ -61,7 +60,7 @@ func NewLinearReader(ctx context.Context, ) (io.ReadCloser, error) { size := parts[0].Size - if file.Encrypted.Value { + if file.Encrypted { size = parts[0].DecryptedSize } r := &LinearReader{ @@ -129,22 +128,22 @@ func (r *LinearReader) moveToNextPart() error { func (r *LinearReader) getPartReader() (io.ReadCloser, error) { currentRange := r.ranges[r.pos] - partID := r.parts[currentRange.PartNo].ID + partId := r.parts[currentRange.PartNo].ID chunkSrc := &chunkSource{ - channelID: r.file.ChannelId.Value, - partID: partID, + channelId: *r.file.ChannelId, + partId: partId, client: r.client, concurrency: r.concurrency, cache: r.cache, - key: fmt.Sprintf("files:location:%s:%d", r.file.ID.Value, partID), + key: cache.Key("files", "location", r.file.ID, partId), } var ( reader io.ReadCloser err error ) - if r.file.Encrypted.Value { + if r.file.Encrypted { salt := r.parts[r.ranges[r.pos].PartNo].Salt cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt) reader, err = cipher.DecryptDataSeek(r.ctx, diff --git a/internal/reader/tg_multi_reader.go b/internal/reader/tg_multi_reader.go index 7856159..60f0ea3 100644 --- a/internal/reader/tg_multi_reader.go +++ b/internal/reader/tg_multi_reader.go @@ -25,8 +25,8 @@ type ChunkSource interface { } type chunkSource struct { - channelID int64 - partID int64 + channelId int64 + partId int64 concurrency int client *tg.Client key string @@ -46,7 +46,7 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b err = c.cache.Get(c.key, location) if err != nil { - location, err = tgc.GetLocation(ctx, c.client, c.channelID, c.partID) + location, err = tgc.GetLocation(ctx, c.client, c.channelId, c.partId) if err != nil { return nil, err } diff --git a/internal/tgc/bolt.go b/internal/tgc/bolt.go index 7e2ba07..67b33a0 100644 --- a/internal/tgc/bolt.go +++ b/internal/tgc/bolt.go @@ -5,14 +5,13 @@ import ( "path/filepath" "time" - "github.com/mitchellh/go-homedir" "github.com/tgdrive/teldrive/internal/utils" "go.etcd.io/bbolt" ) func NewBoltDB(sessionFile string) (*bbolt.DB, error) { if sessionFile == "" { - dir, err := homedir.Dir() + dir, err := os.UserHomeDir() if err != nil { dir = utils.ExecutableDir() } else { diff --git a/internal/tgc/helpers.go b/internal/tgc/helpers.go index 3acc3f6..66d9f15 100644 --- a/internal/tgc/helpers.go +++ b/internal/tgc/helpers.go @@ -12,13 +12,14 @@ import ( "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/tgdrive/teldrive/internal/config" + "github.com/tgdrive/teldrive/internal/utils" "github.com/tgdrive/teldrive/pkg/types" "go.etcd.io/bbolt" "golang.org/x/sync/errgroup" ) var ( - ErrInValidChannelID = errors.New("invalid channel id") + ErrInValidChannelId = errors.New("invalid channel id") ErrInvalidChannelMessages = errors.New("invalid channel messages") ) @@ -33,7 +34,7 @@ func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*t } if len(channels.GetChats()) == 0 { - return nil, ErrInValidChannelID + return nil, ErrInValidChannelId } return channels.GetChats()[0].(*tg.Channel).AsInput(), nil } @@ -71,15 +72,11 @@ func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int6 func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) { - msgIds := []tg.InputMessageClass{} - - for _, id := range ids { - msgIds = append(msgIds, &tg.InputMessageID{ID: id}) - } - messageRequest := tg.ChannelsGetMessagesRequest{ Channel: channel, - ID: msgIds, + ID: utils.Map(ids, func(id int) tg.InputMessageClass { + return &tg.InputMessageID{ID: id} + }), } res, err := client.ChannelsGetMessages(ctx, &messageRequest) diff --git a/internal/tgc/workers.go b/internal/tgc/workers.go index ffd8f46..f92f175 100644 --- a/internal/tgc/workers.go +++ b/internal/tgc/workers.go @@ -17,21 +17,21 @@ func NewBotWorker() *BotWorker { } } -func (w *BotWorker) Set(bots []string, channelID int64) { +func (w *BotWorker) Set(bots []string, channelId int64) { w.mu.Lock() defer w.mu.Unlock() - if _, ok := w.bots[channelID]; ok { + if _, ok := w.bots[channelId]; ok { return } - w.bots[channelID] = bots - w.currIdx[channelID] = 0 + w.bots[channelId] = bots + w.currIdx[channelId] = 0 } -func (w *BotWorker) Next(channelID int64) (string, int) { +func (w *BotWorker) Next(channelId int64) (string, int) { w.mu.RLock() defer w.mu.RUnlock() - bots := w.bots[channelID] - index := w.currIdx[channelID] - w.currIdx[channelID] = (index + 1) % len(bots) + bots := w.bots[channelId] + index := w.currIdx[channelId] + w.currIdx[channelId] = (index + 1) % len(bots) return bots[index], index } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 64e9e97..7f0c0e0 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -5,73 +5,18 @@ import ( "path/filepath" "regexp" "strings" - "time" - - "reflect" - - "unicode" ) -func CamelToPascalCase(input string) string { - var result strings.Builder - upperNext := true - - for _, char := range input { - if unicode.IsLetter(char) || unicode.IsDigit(char) { - if upperNext { - result.WriteRune(unicode.ToUpper(char)) - upperNext = false - } else { - result.WriteRune(char) - } - } else { - upperNext = true - } - } - - return result.String() -} - func CamelToSnake(input string) string { re := regexp.MustCompile("([a-z0-9])([A-Z])") snake := re.ReplaceAllString(input, "${1}_${2}") return strings.ToLower(snake) } -func GetField(v interface{}, field string) string { - r := reflect.ValueOf(v) - f := reflect.Indirect(r).FieldByName(field) - fieldValue := f.Interface() - - switch v := fieldValue.(type) { - case string: - return v - case time.Time: - return v.Format(time.RFC3339) - default: - return "" - } -} - func Ptr[T any](t T) *T { return &t } -func BoolPointer(b bool) *bool { - return &b -} - -func IntPointer(b int) *int { - return &b -} -func Int64Pointer(b int64) *int64 { - return &b -} - -func StringPointer(b string) *string { - return &b -} - func PathExists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { @@ -84,8 +29,34 @@ func PathExists(path string) (bool, error) { } func ExecutableDir() string { - path, _ := os.Executable() - return filepath.Dir(path) } + +func Filter[T any](slice []T, predicate func(T) bool) []T { + var result []T + for _, v := range slice { + if predicate(v) { + result = append(result, v) + } + } + return result +} + +func Map[T any, U any](slice []T, mapper func(T) U) []U { + var result []U + for _, v := range slice { + result = append(result, mapper(v)) + } + return result +} + +func Find[T any](slice []T, predicate func(T) bool) (T, bool) { + for _, v := range slice { + if predicate(v) { + return v, true + } + } + var zero T + return zero, false +} diff --git a/pkg/mapper/mapper.go b/pkg/mapper/mapper.go index f8df2e0..d726e5b 100644 --- a/pkg/mapper/mapper.go +++ b/pkg/mapper/mapper.go @@ -2,47 +2,41 @@ package mapper import ( "github.com/tgdrive/teldrive/internal/api" + "github.com/tgdrive/teldrive/internal/utils" "github.com/tgdrive/teldrive/pkg/models" ) -func ToFileOut(file models.File, extended bool) *api.File { +func ToFileOut(file models.File) *api.File { res := &api.File{ - ID: api.NewOptString(file.Id), + ID: api.NewOptString(file.ID), Name: file.Name, Type: api.FileType(file.Type), MimeType: api.NewOptString(file.MimeType), Encrypted: api.NewOptBool(file.Encrypted), - ParentId: api.NewOptString(file.ParentID.String), UpdatedAt: api.NewOptDateTime(file.UpdatedAt), } + if file.ParentId != nil { + res.ParentId = api.NewOptString(*file.ParentId) + } if file.Size != nil { res.Size = api.NewOptInt64(*file.Size) } if file.Category != "" { res.Category = api.NewOptFileCategory(api.FileCategory(file.Category)) } - if extended { - res.Parts = file.Parts - if file.ChannelID != nil { - res.ChannelId = api.NewOptInt64(*file.ChannelID) - } - } return res } func ToUploadOut(parts []models.Upload) []api.UploadPart { - res := []api.UploadPart{} - for _, part := range parts { - res = append(res, api.UploadPart{ + return utils.Map(parts, func(part models.Upload) api.UploadPart { + return api.UploadPart{ Name: part.Name, PartId: part.PartId, - ChannelId: part.ChannelID, + ChannelId: part.ChannelId, PartNo: part.PartNo, Size: part.Size, Encrypted: part.Encrypted, Salt: api.NewOptString(part.Salt), - }) - - } - return res + } + }) } diff --git a/pkg/models/bot.go b/pkg/models/bot.go index b8a1a76..383ac88 100644 --- a/pkg/models/bot.go +++ b/pkg/models/bot.go @@ -2,8 +2,8 @@ package models type Bot struct { Token string `gorm:"type:text;primaryKey"` - UserID int64 `gorm:"type:bigint"` - BotID int64 `gorm:"type:bigint"` + UserId int64 `gorm:"type:bigint"` + BotId int64 `gorm:"type:bigint"` BotUserName string `gorm:"type:text"` - ChannelID int64 `gorm:"type:bigint"` + ChannelId int64 `gorm:"type:bigint"` } diff --git a/pkg/models/channel.go b/pkg/models/channel.go index 8bb2935..d0cecfe 100644 --- a/pkg/models/channel.go +++ b/pkg/models/channel.go @@ -1,8 +1,8 @@ package models type Channel struct { - ChannelID int64 `gorm:"type:bigint;primaryKey"` + ChannelId int64 `gorm:"type:bigint;primaryKey"` ChannelName string `gorm:"type:text"` - UserID int64 `gorm:"type:bigint;"` + UserId int64 `gorm:"type:bigint;"` Selected bool `gorm:"type:boolean;"` } diff --git a/pkg/models/file.go b/pkg/models/file.go index 7078cbc..a170608 100644 --- a/pkg/models/file.go +++ b/pkg/models/file.go @@ -1,7 +1,6 @@ package models import ( - "database/sql" "time" "github.com/tgdrive/teldrive/internal/api" @@ -9,18 +8,18 @@ import ( ) type File struct { - Id string `gorm:"type:uuid;primaryKey;default:uuid7()"` + ID string `gorm:"type:uuid;primaryKey;default:uuid7()"` Name string `gorm:"type:text;not null"` Type string `gorm:"type:text;not null"` MimeType string `gorm:"type:text;not null"` Size *int64 `gorm:"type:bigint"` Category string `gorm:"type:text"` Encrypted bool `gorm:"default:false"` - UserID int64 `gorm:"type:bigint;not null"` + UserId int64 `gorm:"type:bigint;not null"` Status string `gorm:"type:text"` - ParentID sql.NullString `gorm:"type:uuid;index"` + ParentId *string `gorm:"type:uuid;index"` Parts datatypes.JSONSlice[api.Part] `gorm:"type:jsonb"` - ChannelID *int64 `gorm:"type:bigint"` + ChannelId *int64 `gorm:"type:bigint"` CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"` UpdatedAt time.Time `gorm:"autoUpdateTime:false"` } diff --git a/pkg/models/share.go b/pkg/models/share.go index 5d71ceb..636b6d4 100644 --- a/pkg/models/share.go +++ b/pkg/models/share.go @@ -6,10 +6,10 @@ import ( type FileShare struct { ID string `gorm:"type:uuid;default:uuid_generate_v4();primary_key"` - FileID string `gorm:"type:uuid;not null"` + FileId string `gorm:"type:uuid;not null"` Password *string `gorm:"type:text"` ExpiresAt *time.Time `gorm:"type:timestamp"` CreatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"` UpdatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"` - UserID int64 `gorm:"type:bigint;not null"` + UserId int64 `gorm:"type:bigint;not null"` } diff --git a/pkg/models/upload.go b/pkg/models/upload.go index 2928c74..697575a 100644 --- a/pkg/models/upload.go +++ b/pkg/models/upload.go @@ -12,7 +12,7 @@ type Upload struct { PartId int `gorm:"type:integer"` Encrypted bool `gorm:"default:false"` Salt string `gorm:"type:text"` - ChannelID int64 `gorm:"type:bigint"` + ChannelId int64 `gorm:"type:bigint"` Size int64 `gorm:"type:bigint"` CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"` } diff --git a/pkg/services/api_service.go b/pkg/services/api.go similarity index 100% rename from pkg/services/api_service.go rename to pkg/services/api.go diff --git a/pkg/services/auth.go b/pkg/services/auth.go index 418c8c2..6c0e7d1 100644 --- a/pkg/services/auth.go +++ b/pkg/services/auth.go @@ -25,6 +25,7 @@ import ( "github.com/gotd/td/tgerr" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" + "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/tgc" "github.com/tgdrive/teldrive/pkg/models" @@ -80,7 +81,7 @@ func (a *apiService) AuthLogin(ctx context.Context, session *api.SessionCreate) Name: "root", Type: "folder", MimeType: "drive/folder", - UserID: session.UserId, + UserId: session.UserId, Status: "active", Parts: nil, } @@ -129,7 +130,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent, return err }) a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{}) - a.cache.Delete(fmt.Sprintf("sessions:%s", authUser.Hash)) + a.cache.Delete(cache.Key("sessions", authUser.Hash)) return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil } @@ -360,7 +361,7 @@ func pack32BinaryIP4(ip4Address string) []byte { return buf.Bytes() } -func generateTgSession(dcID int, authKey []byte, port int) string { +func generateTgSession(dcId int, authKey []byte, port int) string { dcMaps := map[int]string{ 1: "149.154.175.53", @@ -370,8 +371,8 @@ func generateTgSession(dcID int, authKey []byte, port int) string { 5: "91.108.56.130", } - dcIDByte := byte(dcID) - serverAddressBytes := pack32BinaryIP4(dcMaps[dcID]) + dcIDByte := byte(dcId) + serverAddressBytes := pack32BinaryIP4(dcMaps[dcId]) portByte := make([]byte, 2) binary.BigEndian.PutUint16(portByte, uint16(port)) diff --git a/pkg/services/common.go b/pkg/services/common.go index 1118ac2..9be6230 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -2,106 +2,88 @@ package services import ( "context" + "errors" "fmt" "time" - "github.com/go-faster/errors" "github.com/gotd/td/telegram" "github.com/gotd/td/tg" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/crypt" + "github.com/tgdrive/teldrive/internal/logging" "github.com/tgdrive/teldrive/internal/tgc" + "github.com/tgdrive/teldrive/internal/utils" "github.com/tgdrive/teldrive/pkg/models" "github.com/tgdrive/teldrive/pkg/types" + "go.uber.org/zap" "gorm.io/gorm" ) -func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *api.File) ([]types.Part, error) { +func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) { + return cache.Fetch(c, cache.Key("files", "messages", file.ID), 60*time.Minute, func() ([]types.Part, error) { + messages, err := tgc.GetMessages(ctx, client.API(), utils.Map(file.Parts, func(part api.Part) int { + return part.ID + }), *file.ChannelId) - parts := []types.Part{} - - key := fmt.Sprintf("files:messages:%s", file.ID.Value) - - err := cache.Get(key, &parts) - - if err == nil { + if err != nil { + return nil, err + } + parts := []types.Part{} + for i, message := range messages { + switch item := message.(type) { + case *tg.Message: + media, ok := item.Media.(*tg.MessageMediaDocument) + if !ok { + continue + } + document, ok := media.Document.(*tg.Document) + if !ok { + continue + } + part := types.Part{ + ID: int64(file.Parts[i].ID), + Size: document.Size, + Salt: file.Parts[i].Salt.Value, + } + if file.Encrypted { + part.DecryptedSize, _ = crypt.DecryptedSize(document.Size) + } + parts = append(parts, part) + } + } + if len(parts) != len(file.Parts) { + msg := "file parts mismatch" + logging.FromContext(ctx).Error(msg, zap.String("name", file.Name), + zap.Int("expected", len(file.Parts)), zap.Int("actual", len(parts))) + return nil, errors.New(msg) + } return parts, nil - } - - ids := []int{} - for _, part := range file.Parts { - ids = append(ids, int(part.ID)) - } - messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value) - - if err != nil { - return nil, err - } - - for i, message := range messages { - item := message.(*tg.Message) - media := item.Media.(*tg.MessageMediaDocument) - document := media.Document.(*tg.Document) - - part := types.Part{ - ID: int64(file.Parts[i].ID), - Size: document.Size, - Salt: file.Parts[i].Salt.Value, - } - if file.Encrypted.Value { - part.DecryptedSize, _ = crypt.DecryptedSize(document.Size) - } - parts = append(parts, part) - } - cache.Set(key, &parts, 60*time.Minute) - return parts, nil + }) } -func getDefaultChannel(db *gorm.DB, cache cache.Cacher, userID int64) (int64, error) { - - var channelId int64 - key := fmt.Sprintf("users:channel:%d", userID) - - err := cache.Get(key, &channelId) - - if err == nil { - return channelId, nil - } - - var channelIds []int64 - db.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true). - Pluck("channel_id", &channelIds) - - if len(channelIds) == 1 { - channelId = channelIds[0] - cache.Set(key, channelId, 0) - } - - if channelId == 0 { - return channelId, errors.New("default channel not set") - } - - return channelId, nil +func getDefaultChannel(db *gorm.DB, c cache.Cacher, userId int64) (int64, error) { + return cache.Fetch(c, cache.Key("users", "channel", userId), 0, func() (int64, error) { + var channelIds []int64 + if err := db.Model(&models.Channel{}).Where("user_id = ?", userId).Where("selected = ?", true). + Pluck("channel_id", &channelIds).Error; err != nil { + return 0, err + } + if len(channelIds) == 0 { + return 0, fmt.Errorf("no default channel found for user %d", userId) + } + return channelIds[0], nil + }) } -func getBotsToken(db *gorm.DB, cache cache.Cacher, userID, channelId int64) ([]string, error) { - var bots []string - - key := fmt.Sprintf("users:bots:%d:%d", userID, channelId) - - err := cache.Get(key, &bots) - - if err == nil { +func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]string, error) { + return cache.Fetch(c, cache.Key("users", "bots", userId, channelId), 0, func() ([]string, error) { + var bots []string + if err := db.Model(&models.Bot{}).Where("user_id = ?", userId). + Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil { + return nil, err + } return bots, nil - } - - if err := db.Model(&models.Bot{}).Where("user_id = ?", userID). - Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil { - return nil, err - } - - cache.Set(key, &bots, 0) - return bots, nil + }) } diff --git a/pkg/services/file.go b/pkg/services/file.go index c6c60b8..1e1625d 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -3,7 +3,6 @@ package services import ( "context" "crypto/rand" - "database/sql" "encoding/binary" "errors" "fmt" @@ -20,6 +19,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" + "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/category" "github.com/tgdrive/teldrive/internal/database" "github.com/tgdrive/teldrive/internal/http_range" @@ -71,7 +71,7 @@ func randInt64() (int64, error) { b := &buffer{Buf: buf[:]} return b.long() } -func isUUID(str string) bool { +func isUUId(str string) bool { _, err := uuid.Parse(str) return err == nil } @@ -97,10 +97,10 @@ func (a *apiService) getFileFromPath(path string, userId int64) (*models.File, e } func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStats, error) { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var stats []api.CategoryStats if err := a.db.Model(&models.File{}).Select("category", "COUNT(*) as total_files", "coalesce(SUM(size),0) as total_size"). - Where(&models.File{UserID: userId, Type: "file", Status: "active"}). + Where(&models.File{UserId: userId, Type: "file", Status: "active"}). Order("category ASC").Group("category").Find(&stats).Error; err != nil { return nil, &apiError{err: err} } @@ -110,9 +110,9 @@ func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStat func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params api.FilesCopyParams) (*api.File, error) { - userId, session := auth.GetUser(ctx) + userId := auth.GetUser(ctx) - client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) + client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...) var res []models.File @@ -133,12 +133,9 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap } err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error { - ids := []int{} - for _, part := range file.Parts { - ids = append(ids, int(part.ID)) - } - messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID) + ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID }) + messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelId) if err != nil { return err @@ -198,13 +195,13 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap } var parentId string - if !isUUID(req.Destination) { + if !isUUId(req.Destination) { var destRes []models.File if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination). Scan(&destRes).Error; err != nil { return nil, &apiError{err: err} } - parentId = destRes[0].Id + parentId = destRes[0].ID } else { parentId = req.Destination } @@ -218,13 +215,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap if len(newIds) > 0 { dbFile.Parts = datatypes.NewJSONSlice(newIds) } - dbFile.UserID = userId + dbFile.UserId = userId dbFile.Status = "active" - dbFile.ParentID = sql.NullString{ - String: parentId, - Valid: true, - } - dbFile.ChannelID = &channelId + dbFile.ParentId = utils.Ptr(parentId) + dbFile.ChannelId = &channelId dbFile.Encrypted = file.Encrypted dbFile.Category = string(file.Category) if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() { @@ -237,11 +231,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap return nil, &apiError{err: err} } - return mapper.ToFileOut(dbFile, false), nil + return mapper.ToFileOut(dbFile), nil } func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.File, error) { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var ( fileDB models.File @@ -267,15 +261,9 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi if err != nil { return nil, &apiError{err: err, code: 404} } - fileDB.ParentID = sql.NullString{ - String: parent.Id, - Valid: true, - } + fileDB.ParentId = utils.Ptr(parent.ID) } else if fileIn.ParentId.Value != "" { - fileDB.ParentID = sql.NullString{ - String: fileIn.ParentId.Value, - Valid: true, - } + fileDB.ParentId = utils.Ptr(fileIn.ParentId.Value) } @@ -291,25 +279,17 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi } else { channelId = fileIn.ChannelId.Value } - fileDB.ChannelID = &channelId + fileDB.ChannelId = &channelId fileDB.MimeType = fileIn.MimeType.Value fileDB.Category = string(category.GetCategory(fileIn.Name)) if len(fileIn.Parts) > 0 { - parts := []api.Part{} - for _, part := range fileIn.Parts { - p := api.Part{ID: part.ID} - if part.Salt.Value != "" { - p.Salt = part.Salt - } - parts = append(parts, p) - } - fileDB.Parts = datatypes.NewJSONSlice(parts) + fileDB.Parts = datatypes.NewJSONSlice(mapParts(fileIn.Parts)) } fileDB.Size = utils.Ptr(fileIn.Size.Or(0)) } fileDB.Name = fileIn.Name fileDB.Type = string(fileIn.Type) - fileDB.UserID = userId + fileDB.UserId = userId fileDB.Status = "active" fileDB.Encrypted = fileIn.Encrypted.Value if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() { @@ -323,11 +303,11 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi } return nil, &apiError{err: err} } - return mapper.ToFileOut(fileDB, false), nil + return mapper.ToFileOut(fileDB), nil } func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCreate, params api.FilesCreateShareParams) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var fileShare models.FileShare @@ -339,11 +319,11 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre fileShare.Password = utils.Ptr(string(bytes)) } - fileShare.FileID = params.ID + fileShare.FileId = params.ID if req.ExpiresAt.IsSet() { fileShare.ExpiresAt = utils.Ptr(req.ExpiresAt.Value) } - fileShare.UserID = userId + fileShare.UserId = userId if err := a.db.Create(&fileShare).Error; err != nil { return &apiError{err: err} @@ -353,7 +333,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre } func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil { return &apiError{err: err} } @@ -362,7 +342,7 @@ func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error } func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDeleteShareParams) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var deletedShare models.FileShare @@ -371,14 +351,14 @@ func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDelet return &apiError{err: err} } if deletedShare.ID != "" { - a.cache.Delete(fmt.Sprintf("shares:%s", deletedShare.ID)) + a.cache.Delete(cache.Key("shared", deletedShare.ID)) } return nil } func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreate, params api.FilesEditShareParams) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var fileShareUpdate models.FileShare @@ -387,7 +367,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat if err != nil { return &apiError{err: err} } - fileShareUpdate.Password = utils.StringPointer(string(bytes)) + fileShareUpdate.Password = utils.Ptr(string(bytes)) } if req.ExpiresAt.IsSet() { fileShareUpdate.ExpiresAt = utils.Ptr(req.ExpiresAt.Value) @@ -403,26 +383,25 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat func (a *apiService) FilesGetById(ctx context.Context, params api.FilesGetByIdParams) (*api.File, error) { var result []fullFileDB - notFoundResponse := &apiError{err: errors.New("file not found"), code: 404} if err := a.db.Model(&models.File{}).Select("*", "(select get_path_from_file_id as path from teldrive.get_path_from_file_id(id))"). Where("id = ?", params.ID).Scan(&result).Error; err != nil { - if database.IsRecordNotFoundErr(err) { - return nil, notFoundResponse - } return nil, &apiError{err: err} } if len(result) == 0 { - return nil, notFoundResponse + return nil, &apiError{err: errors.New("file not found"), code: 404} } - res := mapper.ToFileOut(result[0].File, true) + res := mapper.ToFileOut(result[0].File) res.Path = api.NewOptString(result[0].Path) + if result[0].ChannelId != nil { + res.ChannelId = api.NewOptInt64(*result[0].ChannelId) + } return res, nil } func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams) (*api.FileList, error) { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) queryBuilder := &fileQueryBuilder{db: a.db} @@ -430,7 +409,7 @@ func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams) } func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) if err := a.db.Exec("select * from teldrive.create_directories(?, ?)", userId, req.Path).Error; err != nil { return &apiError{err: err} @@ -439,18 +418,18 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error { } func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) items := pgtype.Array[string]{ Elements: req.Ids, Valid: true, Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}}, } - if !isUUID(req.Destination) { + if !isUUId(req.Destination) { r, err := a.getFileFromPath(req.Destination, userId) if err != nil { return &apiError{err: err} } - req.Destination = r.Id + req.Destination = r.ID } if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId). Update("parent_id", req.Destination).Error; err != nil { @@ -462,7 +441,7 @@ func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error { } func (a *apiService) FilesShareByid(ctx context.Context, params api.FilesShareByidParams) (*api.FileShare, error) { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var result []models.FileShare notFoundErr := &apiError{err: errors.New("invalid share"), code: 404} @@ -500,15 +479,7 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param updateDb.Name = req.Name.Value } if len(req.Parts) > 0 { - parts := []api.Part{} - for _, part := range req.Parts { - p := api.Part{ID: part.ID} - if part.Salt.Value != "" { - p.Salt = part.Salt - } - parts = append(parts, p) - } - updateDb.Parts = datatypes.NewJSONSlice(parts) + updateDb.Parts = datatypes.NewJSONSlice(mapParts(req.Parts)) } if req.Size.Value != 0 { updateDb.Size = utils.Ptr(req.Size.Value) @@ -523,17 +494,18 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param return nil, &apiError{err: err} } - a.cache.Delete(fmt.Sprintf("files:%s", params.ID)) + a.cache.Delete(cache.Key("files", params.ID)) file := models.File{} if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil { return nil, &apiError{err: err} } - return mapper.ToFileOut(file, false), nil + return mapper.ToFileOut(file), nil } func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error { - userId, _ := auth.GetUser(ctx) + + userId := auth.GetUser(ctx) var file models.File @@ -545,29 +517,18 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd if err != nil { return &apiError{err: err} } - updatePayload.ChannelID = &channelId + updatePayload.ChannelId = &channelId } else { - updatePayload.ChannelID = &req.ChannelId.Value + updatePayload.ChannelId = &req.ChannelId.Value } if len(req.Parts) > 0 { - parts := []api.Part{} - for _, part := range req.Parts { - p := api.Part{ID: part.ID} - if part.Salt.Value != "" { - p.Salt = part.Salt - } - parts = append(parts, p) - } - updatePayload.Parts = datatypes.NewJSONSlice(parts) + updatePayload.Parts = datatypes.NewJSONSlice(mapParts(req.Parts)) } if req.Name.Value != "" { updatePayload.Name = req.Name.Value } if req.ParentId.Value != "" { - updatePayload.ParentID = sql.NullString{ - String: req.ParentId.Value, - Valid: true, - } + updatePayload.ParentId = utils.Ptr(req.ParentId.Value) } updatePayload.UpdatedAt = req.UpdatedAt @@ -592,28 +553,23 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd return &apiError{err: err} } - if len(file.Parts) > 0 && file.ChannelID != nil { - _, session := auth.GetUser(ctx) - ids := []int{} + keys := []string{cache.Key("files", params.ID)} + if len(file.Parts) > 0 && file.ChannelId != nil { + ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID }) + client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...) + tgc.DeleteMessages(ctx, client, *file.ChannelId, ids) + keys = append(keys, cache.Key("files", "messages", params.ID)) for _, part := range file.Parts { - ids = append(ids, int(part.ID)) + keys = append(keys, cache.Key("files", "location", params.ID, part.ID)) } - client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) - tgc.DeleteMessages(ctx, client, *file.ChannelID, ids) - keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)} - for _, part := range file.Parts { - keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID)) - - } - a.cache.Delete(keys...) } - a.cache.Delete(fmt.Sprintf("files:%s", params.ID)) + a.cache.Delete(keys...) return nil } -func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileID string, userId int64) { +func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileId string, userId int64) { ctx := r.Context() var ( session *models.Session @@ -646,19 +602,17 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi session = &models.Session{UserId: userId} } - file := &api.File{} - - key := fmt.Sprintf("files:%s", fileID) - - err = e.api.cache.Get(key, file) + file, err := cache.Fetch(e.api.cache, cache.Key("files", fileId), 0, func() (*models.File, error) { + var result models.File + if err := e.api.db.Model(&result).Where("id = ?", fileId).First(&result).Error; err != nil { + return nil, err + } + return &result, nil + }) if err != nil { - file, err = e.api.FilesGetById(ctx, api.FilesGetByIdParams{ID: fileID}) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - e.api.cache.Set(key, file, 0) + http.Error(w, err.Error(), http.StatusBadRequest) + return } w.Header().Set("Accept-Ranges", "bytes") @@ -666,16 +620,15 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi var start, end int64 rangeHeader := r.Header.Get("Range") + contentType := defaultContentType - if file.Size.Value == 0 { - w.Header().Set("Content-Type", file.MimeType.Or(defaultContentType)) + if file.MimeType != "" { + contentType = file.MimeType + } + + if file.Size == nil || *file.Size == 0 { + w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Length", "0") - - if rangeHeader != "" { - w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value)) - http.Error(w, "Requested Range Not Satisfiable", http.StatusRequestedRangeNotSatisfiable) - return - } w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": file.Name})) w.WriteHeader(http.StatusOK) return @@ -684,11 +637,11 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi status := http.StatusOK if rangeHeader == "" { start = 0 - end = file.Size.Value - 1 + end = *file.Size - 1 } else { - ranges, err := http_range.Parse(rangeHeader, file.Size.Value) + ranges, err := http_range.Parse(rangeHeader, *file.Size) if err == http_range.ErrNoOverlap { - w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value)) + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", *file.Size)) http.Error(w, http_range.ErrNoOverlap.Error(), http.StatusRequestedRangeNotSatisfiable) return } @@ -702,20 +655,18 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi } start = ranges[0].Start end = ranges[0].End - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size.Value)) + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, *file.Size)) status = http.StatusPartialContent } contentLength := end - start + 1 - mimeType := file.MimeType.Or(defaultContentType) - - w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10)) - w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileID+strconv.FormatInt(file.Size.Value, 10)))) - w.Header().Set("Last-Modified", file.UpdatedAt.Value.UTC().Format(http.TimeFormat)) + w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileId+strconv.FormatInt(*file.Size, 10)))) + w.Header().Set("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat)) disposition := "inline" @@ -733,7 +684,7 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi return } - tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, file.ChannelId.Value) + tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, *file.ChannelId) if err != nil { http.Error(w, "failed to get bots", http.StatusInternalServerError) @@ -761,9 +712,9 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi multiThreads = 0 } else { - e.api.worker.Set(tokens, file.ChannelId.Value) + e.api.worker.Set(tokens, *file.ChannelId) - token, _ = e.api.worker.Next(file.ChannelId.Value) + token, _ = e.api.worker.Next(*file.ChannelId) client, err = tgc.BotClient(ctx, e.api.boltdb, &e.api.cnf.TG, token, middlewares...) if err != nil { @@ -816,5 +767,16 @@ func (e *extendedService) SharesStream(w http.ResponseWriter, r *http.Request, s http.Error(w, err.Error(), http.StatusUnauthorized) return } - e.FilesStream(w, r, fileId, share.UserID) + e.FilesStream(w, r, fileId, share.UserId) +} + +func mapParts(_parts []api.Part) []api.Part { + return utils.Map(_parts, func(part api.Part) api.Part { + p := api.Part{ID: part.ID} + if part.Salt.Value != "" { + p.Salt = part.Salt + } + return p + }) + } diff --git a/pkg/services/file_query_builder.go b/pkg/services/file_query_builder.go index 19d0f2b..65baf22 100644 --- a/pkg/services/file_query_builder.go +++ b/pkg/services/file_query_builder.go @@ -17,8 +17,7 @@ import ( ) type fileQueryBuilder struct { - db *gorm.DB - selectAllFields bool + db *gorm.DB } type fileResponse struct { @@ -26,7 +25,7 @@ type fileResponse struct { Total int } -var selectedFields = []string{"id", "name", "type", "mime_type", "category", "encrypted", "size", "parent_id", "updated_at"} +var selectedFields = []string{"id", "name", "type", "mime_type", "category", "channel_id", "encrypted", "size", "parent_id", "updated_at"} func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int64) (*api.FileList, error) { query := afb.db.Where("user_id = ?", userId).Where("status = ?", filesQuery.Status.Value) @@ -50,11 +49,7 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int count = res[0].Total } - files := []api.File{} - - for _, file := range res { - files = append(files, *mapper.ToFileOut(file.File, afb.selectAllFields)) - } + files := utils.Map(res, func(item fileResponse) api.File { return *mapper.ToFileOut(item.File) }) return &api.FileList{Items: files, Meta: api.FileListMeta{Count: count, @@ -192,15 +187,10 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File orderField := utils.CamelToSnake(string(filesQuery.Sort.Value)) op := getOrderOperation(filesQuery) - fields := selectedFields - if afb.selectAllFields { - fields = append(fields, "parts", "channel_id") - } - return afb.buildSubqueryCTE(query, filesQuery, userId).Clauses(exclause.NewWith("ranked_scores", afb.db.Model(&models.File{}).Select(orderField, "count(*) OVER () as total", fmt.Sprintf("ROW_NUMBER() OVER (ORDER BY %s %s) AS rank", orderField, strings.ToUpper(string(filesQuery.Order.Value)))). Where(query))).Model(&models.File{}). - Select(fields, "(select total from ranked_scores limit 1) as total"). + Select(selectedFields, "(select total from ranked_scores limit 1) as total"). Where(fmt.Sprintf("%s %s (SELECT %s FROM ranked_scores WHERE rank = ?)", orderField, op, orderField), max((filesQuery.Page.Value-1)*filesQuery.Limit.Value, 1)). Where(query).Order(getOrder(filesQuery)).Limit(filesQuery.Limit.Value) @@ -223,13 +213,6 @@ func getOrder(filesQuery *api.FilesListParams) string { return fmt.Sprintf("%s %s", orderField, strings.ToUpper(string(filesQuery.Order.Value))) } -func max(x int, y int) int { - if x > y { - return x - } - return y -} - func getOrderOperation(filesQuery *api.FilesListParams) string { if filesQuery.Page.Value == 1 { if filesQuery.Order.Value == api.FileQueryOrderAsc { diff --git a/pkg/services/share.go b/pkg/services/share.go index 4d95a45..6beb726 100644 --- a/pkg/services/share.go +++ b/pkg/services/share.go @@ -60,7 +60,7 @@ func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetById } res := &api.FileShareInfo{ Protected: share.Password != nil, - UserId: share.UserID, + UserId: share.UserId, Type: share.Type, Name: share.Name, } @@ -104,16 +104,16 @@ func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListF Status: api.NewOptFileQueryStatus(api.FileQueryStatusActive), Order: api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))), Sort: api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))), - Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserID) + Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserId) } else { var file models.File - if err := a.db.Where("id = ?", share.FileID).First(&file).Error; err != nil { + if err := a.db.Where("id = ?", share.FileId).First(&file).Error; err != nil { if database.IsRecordNotFoundErr(err) { return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound} } return nil, &apiError{err: err} } - return &api.FileList{Items: []api.File{*mapper.ToFileOut(file, false)}, + return &api.FileList{Items: []api.File{*mapper.ToFileOut(file)}, Meta: api.FileListMeta{Count: 1, TotalPages: 1, CurrentPage: 1}}, nil } diff --git a/pkg/services/upload.go b/pkg/services/upload.go index bbf78e8..945e57b 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -48,7 +48,7 @@ func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPar } func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var stats []api.UploadStats err := a.db.Raw(` SELECT @@ -94,7 +94,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400} } - userId, session := auth.GetUser(ctx) + userId := auth.GetUser(ctx) fileStream := req.Content.Data @@ -116,7 +116,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe } if len(tokens) == 0 { - client, err = tgc.AuthClient(ctx, &a.cnf.TG, session) + client, err = tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession) if err != nil { return nil, err } @@ -220,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe Name: params.PartName, UploadId: params.ID, PartId: message.ID, - ChannelID: channelId, + ChannelId: channelId, Size: fileSize, PartNo: int(params.PartNo), UserId: userId, @@ -244,7 +244,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe out = api.UploadPart{ Name: partUpload.Name, PartId: partUpload.PartId, - ChannelId: partUpload.ChannelID, + ChannelId: partUpload.ChannelId, PartNo: partUpload.PartNo, Size: partUpload.Size, Encrypted: partUpload.Encrypted, diff --git a/pkg/services/user.go b/pkg/services/user.go index 8d76674..81fbec4 100644 --- a/pkg/services/user.go +++ b/pkg/services/user.go @@ -16,6 +16,7 @@ import ( "github.com/gotd/td/tgerr" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" + "github.com/tgdrive/teldrive/internal/cache" "github.com/tgdrive/teldrive/internal/tgc" "github.com/tgdrive/teldrive/pkg/models" "github.com/tgdrive/teldrive/pkg/types" @@ -27,8 +28,8 @@ import ( ) func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error { - userId, session := auth.GetUser(ctx) - client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) + userId := auth.GetUser(ctx) + client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...) if len(req.Bots) > 0 { channelId, err := getDefaultChannel(a.db, a.cache, userId) @@ -47,11 +48,11 @@ func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error { func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, error) { - userID, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) channels := make(map[int64]*api.Channel) - peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userID))) + peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId))) iter, err := peerStorage.Iterate(ctx) if err != nil { @@ -67,6 +68,7 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro } } + res := []api.Channel{} for _, channel := range channels { res = append(res, *channel) @@ -79,10 +81,10 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro } func (a *apiService) UsersSyncChannels(ctx context.Context) error { - userId, session := auth.GetUser(ctx) + userId := auth.GetUser(ctx) peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId))) collector := storage.CollectPeers(peerStorage) - client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) + client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...) if err != nil { return &apiError{err: err} } @@ -96,7 +98,9 @@ func (a *apiService) UsersSyncChannels(ctx context.Context) error { } func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, error) { - userId, userSession := auth.GetUser(ctx) + userId := auth.GetUser(ctx) + + userSession := auth.GetJWTUser(ctx).TgSession client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...) @@ -150,9 +154,8 @@ func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, } func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProfileImageParams) (*api.UsersProfileImageOKHeaders, error) { - _, session := auth.GetUser(ctx) - client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...) + client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...) if err != nil { return nil, &apiError{err: err} @@ -194,25 +197,25 @@ func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProf } func (a *apiService) UsersRemoveBots(ctx context.Context) error { - userID, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) - channelId, err := getDefaultChannel(a.db, a.cache, userID) + channelId, err := getDefaultChannel(a.db, a.cache, userId) if err != nil { return &apiError{err: err} } - if err := a.db.Where("user_id = ?", userID).Where("channel_id = ?", channelId). + if err := a.db.Where("user_id = ?", userId).Where("channel_id = ?", channelId). Delete(&models.Bot{}).Error; err != nil { return &apiError{err: err} } - a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId)) + a.cache.Delete(cache.Key("users", "bots", userId, channelId)) return nil } func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRemoveSessionParams) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) session := &models.Session{} @@ -236,15 +239,18 @@ func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRem } func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) { - userID, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) var ( channelId int64 err error ) - channelId, _ = getDefaultChannel(a.db, a.cache, userID) + channelId, err = getDefaultChannel(a.db, a.cache, userId) + if err != nil { + return nil, &apiError{err: err} + } - tokens, err := getBotsToken(a.db, a.cache, userID, channelId) + tokens, err := getBotsToken(a.db, a.cache, userId, channelId) if err != nil { return nil, &apiError{err: err} @@ -253,12 +259,12 @@ func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) { } func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpdate) error { - userId, _ := auth.GetUser(ctx) + userId := auth.GetUser(ctx) - channel := &models.Channel{UserID: userId, Selected: true} + channel := &models.Channel{UserId: userId, Selected: true} if req.ChannelId.Value != 0 { - channel.ChannelID = req.ChannelId.Value + channel.ChannelId = req.ChannelId.Value } if req.ChannelName.Value != "" { channel.ChannelName = req.ChannelName.Value @@ -270,11 +276,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd }).Create(channel).Error; err != nil { return &apiError{err: errors.New("failed to update channel")} } - a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelID). + a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelId). Where("user_id = ?", userId).Update("selected", false) - key := fmt.Sprintf("users:channel:%d", userId) - a.cache.Set(key, channel.ChannelID, 0) + a.cache.Set(cache.Key("users", "channel", userId), channel.ChannelId, 0) return nil } @@ -359,12 +364,12 @@ func (a *apiService) addBots(c context.Context, client *telegram.Client, userId payload := []models.Bot{} for _, info := range botInfoMap { - payload = append(payload, models.Bot{UserID: userId, Token: info.Token, BotID: info.Id, - BotUserName: info.UserName, ChannelID: channelId, + payload = append(payload, models.Bot{UserId: userId, Token: info.Token, BotId: info.Id, + BotUserName: info.UserName, ChannelId: channelId, }) } - a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId)) + a.cache.Delete(cache.Key("users", "bots", userId, channelId)) if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil { return err