mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-09-05 05:54:55 +08:00
refactor: standardize field naming and improve cache key generation
This commit is contained in:
parent
299f4fa7ec
commit
f0187f4052
26 changed files with 516 additions and 408 deletions
4
go.mod
4
go.mod
|
@ -8,13 +8,12 @@ require (
|
|||
github.com/go-chi/chi/v5 v5.2.0
|
||||
github.com/go-chi/cors v1.2.1
|
||||
github.com/go-co-op/gocron v1.37.0
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gotd/contrib v0.21.0
|
||||
github.com/gotd/td v0.117.0
|
||||
github.com/iyear/connectproxy v0.1.1
|
||||
github.com/mitchellh/go-homedir v1.1.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/ogen-go/ogen v1.8.1
|
||||
github.com/redis/go-redis/v9 v9.7.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
|
@ -50,6 +49,7 @@ require (
|
|||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -79,6 +79,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
|
|||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
|
@ -148,8 +150,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
|
|||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw=
|
||||
github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
|
|
|
@ -42,10 +42,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
|
|||
|
||||
}
|
||||
|
||||
func GetUser(c context.Context) (int64, string) {
|
||||
func GetUser(c context.Context) int64 {
|
||||
authUser, _ := c.Value(authKey).(*types.JWTClaims)
|
||||
userId, _ := strconv.ParseInt(authUser.Subject, 10, 64)
|
||||
return userId, authUser.TgSession
|
||||
return userId
|
||||
}
|
||||
|
||||
func GetJWTUser(c context.Context) *types.JWTClaims {
|
||||
|
|
74
internal/cache/cache.go
vendored
74
internal/cache/cache.go
vendored
|
@ -2,6 +2,10 @@ package cache
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -29,8 +33,16 @@ func NewCache(ctx context.Context, conf *config.CacheConfig) Cacher {
|
|||
cacher = NewMemoryCache(conf.MaxSize)
|
||||
} else {
|
||||
cacher = NewRedisCache(ctx, redis.NewClient(&redis.Options{
|
||||
Addr: conf.RedisAddr,
|
||||
Password: conf.RedisPass,
|
||||
Addr: conf.RedisAddr,
|
||||
Password: conf.RedisPass,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 5,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 1 * time.Hour,
|
||||
}))
|
||||
}
|
||||
return cacher
|
||||
|
@ -119,3 +131,61 @@ func (r *RedisCache) Delete(keys ...string) error {
|
|||
}
|
||||
return r.client.Del(r.ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func Fetch[T any](cache Cacher, key string, expiration time.Duration, fn func() (T, error)) (T, error) {
|
||||
var zero, value T
|
||||
err := cache.Get(key, &value)
|
||||
if err != nil {
|
||||
if errors.Is(err, freecache.ErrNotFound) || errors.Is(err, redis.Nil) {
|
||||
value, err = fn()
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
cache.Set(key, &value, expiration)
|
||||
return value, nil
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func Key(args ...interface{}) string {
|
||||
parts := make([]string, len(args))
|
||||
for i, arg := range args {
|
||||
parts[i] = formatValue(arg)
|
||||
}
|
||||
return strings.Join(parts, ":")
|
||||
}
|
||||
|
||||
func formatValue(v interface{}) string {
|
||||
if v == nil {
|
||||
return "nil"
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(v)
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr:
|
||||
if val.IsNil() {
|
||||
return "nil"
|
||||
}
|
||||
return formatValue(val.Elem().Interface())
|
||||
case reflect.Array, reflect.Slice:
|
||||
parts := make([]string, val.Len())
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
parts[i] = formatValue(val.Index(i).Interface())
|
||||
}
|
||||
return fmt.Sprintf("[%s]", strings.Join(parts, ","))
|
||||
case reflect.Map:
|
||||
parts := make([]string, 0, val.Len())
|
||||
for _, key := range val.MapKeys() {
|
||||
k := formatValue(key.Interface())
|
||||
v := formatValue(val.MapIndex(key).Interface())
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return fmt.Sprintf("{%s}", strings.Join(parts, ","))
|
||||
case reflect.Struct:
|
||||
return fmt.Sprintf("%+v", v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
|
152
internal/cache/cache_test.go
vendored
152
internal/cache/cache_test.go
vendored
|
@ -5,16 +5,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
|
||||
var value = api.File{
|
||||
var value = models.File{
|
||||
Name: "file.jpeg",
|
||||
Type: "file",
|
||||
}
|
||||
var result api.File
|
||||
var result models.File
|
||||
|
||||
cache := NewMemoryCache(1 * 1024 * 1024)
|
||||
|
||||
|
@ -25,3 +25,149 @@ func TestCache(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, value)
|
||||
}
|
||||
|
||||
func TestKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple strings",
|
||||
args: []interface{}{"user", "123"},
|
||||
expected: "user:123",
|
||||
},
|
||||
{
|
||||
name: "mixed types",
|
||||
args: []interface{}{"cache", 123, true},
|
||||
expected: "cache:123:true",
|
||||
},
|
||||
{
|
||||
name: "with nil",
|
||||
args: []interface{}{"key", nil, "value"},
|
||||
expected: "key:nil:value",
|
||||
},
|
||||
{
|
||||
name: "empty args",
|
||||
args: []interface{}{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single arg",
|
||||
args: []interface{}{"solo"},
|
||||
expected: "solo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := Key(tt.args...)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Key() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatValue(t *testing.T) {
|
||||
type testStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil value",
|
||||
input: nil,
|
||||
expected: "nil",
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
input: "test",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "integer",
|
||||
input: 123,
|
||||
expected: "123",
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
input: true,
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "slice of strings",
|
||||
input: []string{"a", "b", "c"},
|
||||
expected: "[a,b,c]",
|
||||
},
|
||||
{
|
||||
name: "slice of ints",
|
||||
input: []int{1, 2, 3},
|
||||
expected: "[1,2,3]",
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
input: []string{},
|
||||
expected: "[]",
|
||||
},
|
||||
{
|
||||
name: "map string to string",
|
||||
input: map[string]string{"a": "1", "b": "2"},
|
||||
expected: "{a=1,b=2}",
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]string{},
|
||||
expected: "{}",
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
input: testStruct{Name: "John", Age: 30},
|
||||
expected: "{Name:John Age:30}",
|
||||
},
|
||||
{
|
||||
name: "pointer to string",
|
||||
input: func() interface{} { s := "test"; return &s }(),
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: func() interface{} { var s *string; return s }(),
|
||||
expected: "nil",
|
||||
},
|
||||
{
|
||||
name: "nested slice",
|
||||
input: [][]int{{1, 2}, {3, 4}},
|
||||
expected: "[[1,2],[3,4]]",
|
||||
},
|
||||
{
|
||||
name: "complex mixed structure",
|
||||
input: struct {
|
||||
ID int
|
||||
Tags []string
|
||||
Meta map[string]interface{}
|
||||
Valid bool
|
||||
}{
|
||||
ID: 1,
|
||||
Tags: []string{"a", "b"},
|
||||
Meta: map[string]interface{}{"count": 42},
|
||||
Valid: true,
|
||||
},
|
||||
expected: "{ID:1 Tags:[a b] Meta:map[count:42] Valid:true}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatValue(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("formatValue() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,13 +2,13 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
|
@ -144,7 +144,7 @@ func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error {
|
|||
if cfgFile != "" {
|
||||
cl.v.SetConfigFile(cfgFile)
|
||||
} else {
|
||||
home, err := homedir.Dir()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting home directory: %v", err)
|
||||
}
|
||||
|
|
|
@ -2,14 +2,13 @@ package reader
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/config"
|
||||
"github.com/tgdrive/teldrive/internal/crypt"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
)
|
||||
|
||||
|
@ -20,7 +19,7 @@ type Range struct {
|
|||
|
||||
type LinearReader struct {
|
||||
ctx context.Context
|
||||
file *api.File
|
||||
file *models.File
|
||||
parts []types.Part
|
||||
ranges []Range
|
||||
pos int
|
||||
|
@ -52,7 +51,7 @@ func calculatePartByteRanges(start, end, partSize int64) []Range {
|
|||
func NewLinearReader(ctx context.Context,
|
||||
client *tg.Client,
|
||||
cache cache.Cacher,
|
||||
file *api.File,
|
||||
file *models.File,
|
||||
parts []types.Part,
|
||||
start,
|
||||
end int64,
|
||||
|
@ -61,7 +60,7 @@ func NewLinearReader(ctx context.Context,
|
|||
) (io.ReadCloser, error) {
|
||||
|
||||
size := parts[0].Size
|
||||
if file.Encrypted.Value {
|
||||
if file.Encrypted {
|
||||
size = parts[0].DecryptedSize
|
||||
}
|
||||
r := &LinearReader{
|
||||
|
@ -129,22 +128,22 @@ func (r *LinearReader) moveToNextPart() error {
|
|||
|
||||
func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
|
||||
currentRange := r.ranges[r.pos]
|
||||
partID := r.parts[currentRange.PartNo].ID
|
||||
partId := r.parts[currentRange.PartNo].ID
|
||||
|
||||
chunkSrc := &chunkSource{
|
||||
channelID: r.file.ChannelId.Value,
|
||||
partID: partID,
|
||||
channelId: *r.file.ChannelId,
|
||||
partId: partId,
|
||||
client: r.client,
|
||||
concurrency: r.concurrency,
|
||||
cache: r.cache,
|
||||
key: fmt.Sprintf("files:location:%s:%d", r.file.ID.Value, partID),
|
||||
key: cache.Key("files", "location", r.file.ID, partId),
|
||||
}
|
||||
|
||||
var (
|
||||
reader io.ReadCloser
|
||||
err error
|
||||
)
|
||||
if r.file.Encrypted.Value {
|
||||
if r.file.Encrypted {
|
||||
salt := r.parts[r.ranges[r.pos].PartNo].Salt
|
||||
cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
|
||||
reader, err = cipher.DecryptDataSeek(r.ctx,
|
||||
|
|
|
@ -25,8 +25,8 @@ type ChunkSource interface {
|
|||
}
|
||||
|
||||
type chunkSource struct {
|
||||
channelID int64
|
||||
partID int64
|
||||
channelId int64
|
||||
partId int64
|
||||
concurrency int
|
||||
client *tg.Client
|
||||
key string
|
||||
|
@ -46,7 +46,7 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
|
|||
err = c.cache.Get(c.key, location)
|
||||
|
||||
if err != nil {
|
||||
location, err = tgc.GetLocation(ctx, c.client, c.channelID, c.partID)
|
||||
location, err = tgc.GetLocation(ctx, c.client, c.channelId, c.partId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -5,14 +5,13 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/tgdrive/teldrive/internal/utils"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func NewBoltDB(sessionFile string) (*bbolt.DB, error) {
|
||||
if sessionFile == "" {
|
||||
dir, err := homedir.Dir()
|
||||
dir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
dir = utils.ExecutableDir()
|
||||
} else {
|
||||
|
|
|
@ -12,13 +12,14 @@ import (
|
|||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/tgdrive/teldrive/internal/config"
|
||||
"github.com/tgdrive/teldrive/internal/utils"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInValidChannelID = errors.New("invalid channel id")
|
||||
ErrInValidChannelId = errors.New("invalid channel id")
|
||||
ErrInvalidChannelMessages = errors.New("invalid channel messages")
|
||||
)
|
||||
|
||||
|
@ -33,7 +34,7 @@ func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*t
|
|||
}
|
||||
|
||||
if len(channels.GetChats()) == 0 {
|
||||
return nil, ErrInValidChannelID
|
||||
return nil, ErrInValidChannelId
|
||||
}
|
||||
return channels.GetChats()[0].(*tg.Channel).AsInput(), nil
|
||||
}
|
||||
|
@ -71,15 +72,11 @@ func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int6
|
|||
|
||||
func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {
|
||||
|
||||
msgIds := []tg.InputMessageClass{}
|
||||
|
||||
for _, id := range ids {
|
||||
msgIds = append(msgIds, &tg.InputMessageID{ID: id})
|
||||
}
|
||||
|
||||
messageRequest := tg.ChannelsGetMessagesRequest{
|
||||
Channel: channel,
|
||||
ID: msgIds,
|
||||
ID: utils.Map(ids, func(id int) tg.InputMessageClass {
|
||||
return &tg.InputMessageID{ID: id}
|
||||
}),
|
||||
}
|
||||
|
||||
res, err := client.ChannelsGetMessages(ctx, &messageRequest)
|
||||
|
|
|
@ -17,21 +17,21 @@ func NewBotWorker() *BotWorker {
|
|||
}
|
||||
}
|
||||
|
||||
func (w *BotWorker) Set(bots []string, channelID int64) {
|
||||
func (w *BotWorker) Set(bots []string, channelId int64) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if _, ok := w.bots[channelID]; ok {
|
||||
if _, ok := w.bots[channelId]; ok {
|
||||
return
|
||||
}
|
||||
w.bots[channelID] = bots
|
||||
w.currIdx[channelID] = 0
|
||||
w.bots[channelId] = bots
|
||||
w.currIdx[channelId] = 0
|
||||
}
|
||||
|
||||
func (w *BotWorker) Next(channelID int64) (string, int) {
|
||||
func (w *BotWorker) Next(channelId int64) (string, int) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
bots := w.bots[channelID]
|
||||
index := w.currIdx[channelID]
|
||||
w.currIdx[channelID] = (index + 1) % len(bots)
|
||||
bots := w.bots[channelId]
|
||||
index := w.currIdx[channelId]
|
||||
w.currIdx[channelId] = (index + 1) % len(bots)
|
||||
return bots[index], index
|
||||
}
|
||||
|
|
|
@ -5,73 +5,18 @@ import (
|
|||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"reflect"
|
||||
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func CamelToPascalCase(input string) string {
|
||||
var result strings.Builder
|
||||
upperNext := true
|
||||
|
||||
for _, char := range input {
|
||||
if unicode.IsLetter(char) || unicode.IsDigit(char) {
|
||||
if upperNext {
|
||||
result.WriteRune(unicode.ToUpper(char))
|
||||
upperNext = false
|
||||
} else {
|
||||
result.WriteRune(char)
|
||||
}
|
||||
} else {
|
||||
upperNext = true
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func CamelToSnake(input string) string {
|
||||
re := regexp.MustCompile("([a-z0-9])([A-Z])")
|
||||
snake := re.ReplaceAllString(input, "${1}_${2}")
|
||||
return strings.ToLower(snake)
|
||||
}
|
||||
|
||||
func GetField(v interface{}, field string) string {
|
||||
r := reflect.ValueOf(v)
|
||||
f := reflect.Indirect(r).FieldByName(field)
|
||||
fieldValue := f.Interface()
|
||||
|
||||
switch v := fieldValue.(type) {
|
||||
case string:
|
||||
return v
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func Ptr[T any](t T) *T {
|
||||
return &t
|
||||
}
|
||||
|
||||
func BoolPointer(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func IntPointer(b int) *int {
|
||||
return &b
|
||||
}
|
||||
func Int64Pointer(b int64) *int64 {
|
||||
return &b
|
||||
}
|
||||
|
||||
func StringPointer(b string) *string {
|
||||
return &b
|
||||
}
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
|
@ -84,8 +29,34 @@ func PathExists(path string) (bool, error) {
|
|||
}
|
||||
|
||||
func ExecutableDir() string {
|
||||
|
||||
path, _ := os.Executable()
|
||||
|
||||
return filepath.Dir(path)
|
||||
}
|
||||
|
||||
func Filter[T any](slice []T, predicate func(T) bool) []T {
|
||||
var result []T
|
||||
for _, v := range slice {
|
||||
if predicate(v) {
|
||||
result = append(result, v)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func Map[T any, U any](slice []T, mapper func(T) U) []U {
|
||||
var result []U
|
||||
for _, v := range slice {
|
||||
result = append(result, mapper(v))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func Find[T any](slice []T, predicate func(T) bool) (T, bool) {
|
||||
for _, v := range slice {
|
||||
if predicate(v) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
var zero T
|
||||
return zero, false
|
||||
}
|
||||
|
|
|
@ -2,47 +2,41 @@ package mapper
|
|||
|
||||
import (
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/utils"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
)
|
||||
|
||||
func ToFileOut(file models.File, extended bool) *api.File {
|
||||
func ToFileOut(file models.File) *api.File {
|
||||
res := &api.File{
|
||||
ID: api.NewOptString(file.Id),
|
||||
ID: api.NewOptString(file.ID),
|
||||
Name: file.Name,
|
||||
Type: api.FileType(file.Type),
|
||||
MimeType: api.NewOptString(file.MimeType),
|
||||
Encrypted: api.NewOptBool(file.Encrypted),
|
||||
ParentId: api.NewOptString(file.ParentID.String),
|
||||
UpdatedAt: api.NewOptDateTime(file.UpdatedAt),
|
||||
}
|
||||
if file.ParentId != nil {
|
||||
res.ParentId = api.NewOptString(*file.ParentId)
|
||||
}
|
||||
if file.Size != nil {
|
||||
res.Size = api.NewOptInt64(*file.Size)
|
||||
}
|
||||
if file.Category != "" {
|
||||
res.Category = api.NewOptFileCategory(api.FileCategory(file.Category))
|
||||
}
|
||||
if extended {
|
||||
res.Parts = file.Parts
|
||||
if file.ChannelID != nil {
|
||||
res.ChannelId = api.NewOptInt64(*file.ChannelID)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func ToUploadOut(parts []models.Upload) []api.UploadPart {
|
||||
res := []api.UploadPart{}
|
||||
for _, part := range parts {
|
||||
res = append(res, api.UploadPart{
|
||||
return utils.Map(parts, func(part models.Upload) api.UploadPart {
|
||||
return api.UploadPart{
|
||||
Name: part.Name,
|
||||
PartId: part.PartId,
|
||||
ChannelId: part.ChannelID,
|
||||
ChannelId: part.ChannelId,
|
||||
PartNo: part.PartNo,
|
||||
Size: part.Size,
|
||||
Encrypted: part.Encrypted,
|
||||
Salt: api.NewOptString(part.Salt),
|
||||
})
|
||||
|
||||
}
|
||||
return res
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package models
|
|||
|
||||
type Bot struct {
|
||||
Token string `gorm:"type:text;primaryKey"`
|
||||
UserID int64 `gorm:"type:bigint"`
|
||||
BotID int64 `gorm:"type:bigint"`
|
||||
UserId int64 `gorm:"type:bigint"`
|
||||
BotId int64 `gorm:"type:bigint"`
|
||||
BotUserName string `gorm:"type:text"`
|
||||
ChannelID int64 `gorm:"type:bigint"`
|
||||
ChannelId int64 `gorm:"type:bigint"`
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package models
|
||||
|
||||
type Channel struct {
|
||||
ChannelID int64 `gorm:"type:bigint;primaryKey"`
|
||||
ChannelId int64 `gorm:"type:bigint;primaryKey"`
|
||||
ChannelName string `gorm:"type:text"`
|
||||
UserID int64 `gorm:"type:bigint;"`
|
||||
UserId int64 `gorm:"type:bigint;"`
|
||||
Selected bool `gorm:"type:boolean;"`
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
|
@ -9,18 +8,18 @@ import (
|
|||
)
|
||||
|
||||
type File struct {
|
||||
Id string `gorm:"type:uuid;primaryKey;default:uuid7()"`
|
||||
ID string `gorm:"type:uuid;primaryKey;default:uuid7()"`
|
||||
Name string `gorm:"type:text;not null"`
|
||||
Type string `gorm:"type:text;not null"`
|
||||
MimeType string `gorm:"type:text;not null"`
|
||||
Size *int64 `gorm:"type:bigint"`
|
||||
Category string `gorm:"type:text"`
|
||||
Encrypted bool `gorm:"default:false"`
|
||||
UserID int64 `gorm:"type:bigint;not null"`
|
||||
UserId int64 `gorm:"type:bigint;not null"`
|
||||
Status string `gorm:"type:text"`
|
||||
ParentID sql.NullString `gorm:"type:uuid;index"`
|
||||
ParentId *string `gorm:"type:uuid;index"`
|
||||
Parts datatypes.JSONSlice[api.Part] `gorm:"type:jsonb"`
|
||||
ChannelID *int64 `gorm:"type:bigint"`
|
||||
ChannelId *int64 `gorm:"type:bigint"`
|
||||
CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
|
||||
}
|
||||
|
|
|
@ -6,10 +6,10 @@ import (
|
|||
|
||||
type FileShare struct {
|
||||
ID string `gorm:"type:uuid;default:uuid_generate_v4();primary_key"`
|
||||
FileID string `gorm:"type:uuid;not null"`
|
||||
FileId string `gorm:"type:uuid;not null"`
|
||||
Password *string `gorm:"type:text"`
|
||||
ExpiresAt *time.Time `gorm:"type:timestamp"`
|
||||
CreatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"`
|
||||
UpdatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp"`
|
||||
UserID int64 `gorm:"type:bigint;not null"`
|
||||
UserId int64 `gorm:"type:bigint;not null"`
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type Upload struct {
|
|||
PartId int `gorm:"type:integer"`
|
||||
Encrypted bool `gorm:"default:false"`
|
||||
Salt string `gorm:"type:text"`
|
||||
ChannelID int64 `gorm:"type:bigint"`
|
||||
ChannelId int64 `gorm:"type:bigint"`
|
||||
Size int64 `gorm:"type:bigint"`
|
||||
CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"`
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/gotd/td/tgerr"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/auth"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/logging"
|
||||
"github.com/tgdrive/teldrive/internal/tgc"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
|
@ -80,7 +81,7 @@ func (a *apiService) AuthLogin(ctx context.Context, session *api.SessionCreate)
|
|||
Name: "root",
|
||||
Type: "folder",
|
||||
MimeType: "drive/folder",
|
||||
UserID: session.UserId,
|
||||
UserId: session.UserId,
|
||||
Status: "active",
|
||||
Parts: nil,
|
||||
}
|
||||
|
@ -129,7 +130,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
|
|||
return err
|
||||
})
|
||||
a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{})
|
||||
a.cache.Delete(fmt.Sprintf("sessions:%s", authUser.Hash))
|
||||
a.cache.Delete(cache.Key("sessions", authUser.Hash))
|
||||
return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil
|
||||
}
|
||||
|
||||
|
@ -360,7 +361,7 @@ func pack32BinaryIP4(ip4Address string) []byte {
|
|||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateTgSession(dcID int, authKey []byte, port int) string {
|
||||
func generateTgSession(dcId int, authKey []byte, port int) string {
|
||||
|
||||
dcMaps := map[int]string{
|
||||
1: "149.154.175.53",
|
||||
|
@ -370,8 +371,8 @@ func generateTgSession(dcID int, authKey []byte, port int) string {
|
|||
5: "91.108.56.130",
|
||||
}
|
||||
|
||||
dcIDByte := byte(dcID)
|
||||
serverAddressBytes := pack32BinaryIP4(dcMaps[dcID])
|
||||
dcIDByte := byte(dcId)
|
||||
serverAddressBytes := pack32BinaryIP4(dcMaps[dcId])
|
||||
portByte := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(portByte, uint16(port))
|
||||
|
||||
|
|
|
@ -2,106 +2,88 @@ package services
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/tg"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/crypt"
|
||||
"github.com/tgdrive/teldrive/internal/logging"
|
||||
"github.com/tgdrive/teldrive/internal/tgc"
|
||||
"github.com/tgdrive/teldrive/internal/utils"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *api.File) ([]types.Part, error) {
|
||||
func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) {
|
||||
return cache.Fetch(c, cache.Key("files", "messages", file.ID), 60*time.Minute, func() ([]types.Part, error) {
|
||||
messages, err := tgc.GetMessages(ctx, client.API(), utils.Map(file.Parts, func(part api.Part) int {
|
||||
return part.ID
|
||||
}), *file.ChannelId)
|
||||
|
||||
parts := []types.Part{}
|
||||
|
||||
key := fmt.Sprintf("files:messages:%s", file.ID.Value)
|
||||
|
||||
err := cache.Get(key, &parts)
|
||||
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := []types.Part{}
|
||||
for i, message := range messages {
|
||||
switch item := message.(type) {
|
||||
case *tg.Message:
|
||||
media, ok := item.Media.(*tg.MessageMediaDocument)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
document, ok := media.Document.(*tg.Document)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
part := types.Part{
|
||||
ID: int64(file.Parts[i].ID),
|
||||
Size: document.Size,
|
||||
Salt: file.Parts[i].Salt.Value,
|
||||
}
|
||||
if file.Encrypted {
|
||||
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
if len(parts) != len(file.Parts) {
|
||||
msg := "file parts mismatch"
|
||||
logging.FromContext(ctx).Error(msg, zap.String("name", file.Name),
|
||||
zap.Int("expected", len(file.Parts)), zap.Int("actual", len(parts)))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
ids := []int{}
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
item := message.(*tg.Message)
|
||||
media := item.Media.(*tg.MessageMediaDocument)
|
||||
document := media.Document.(*tg.Document)
|
||||
|
||||
part := types.Part{
|
||||
ID: int64(file.Parts[i].ID),
|
||||
Size: document.Size,
|
||||
Salt: file.Parts[i].Salt.Value,
|
||||
}
|
||||
if file.Encrypted.Value {
|
||||
part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
cache.Set(key, &parts, 60*time.Minute)
|
||||
return parts, nil
|
||||
})
|
||||
}
|
||||
|
||||
func getDefaultChannel(db *gorm.DB, cache cache.Cacher, userID int64) (int64, error) {
|
||||
|
||||
var channelId int64
|
||||
key := fmt.Sprintf("users:channel:%d", userID)
|
||||
|
||||
err := cache.Get(key, &channelId)
|
||||
|
||||
if err == nil {
|
||||
return channelId, nil
|
||||
}
|
||||
|
||||
var channelIds []int64
|
||||
db.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
|
||||
Pluck("channel_id", &channelIds)
|
||||
|
||||
if len(channelIds) == 1 {
|
||||
channelId = channelIds[0]
|
||||
cache.Set(key, channelId, 0)
|
||||
}
|
||||
|
||||
if channelId == 0 {
|
||||
return channelId, errors.New("default channel not set")
|
||||
}
|
||||
|
||||
return channelId, nil
|
||||
func getDefaultChannel(db *gorm.DB, c cache.Cacher, userId int64) (int64, error) {
|
||||
return cache.Fetch(c, cache.Key("users", "channel", userId), 0, func() (int64, error) {
|
||||
var channelIds []int64
|
||||
if err := db.Model(&models.Channel{}).Where("user_id = ?", userId).Where("selected = ?", true).
|
||||
Pluck("channel_id", &channelIds).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(channelIds) == 0 {
|
||||
return 0, fmt.Errorf("no default channel found for user %d", userId)
|
||||
}
|
||||
return channelIds[0], nil
|
||||
})
|
||||
}
|
||||
|
||||
func getBotsToken(db *gorm.DB, cache cache.Cacher, userID, channelId int64) ([]string, error) {
|
||||
var bots []string
|
||||
|
||||
key := fmt.Sprintf("users:bots:%d:%d", userID, channelId)
|
||||
|
||||
err := cache.Get(key, &bots)
|
||||
|
||||
if err == nil {
|
||||
func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]string, error) {
|
||||
return cache.Fetch(c, cache.Key("users", "bots", userId, channelId), 0, func() ([]string, error) {
|
||||
var bots []string
|
||||
if err := db.Model(&models.Bot{}).Where("user_id = ?", userId).
|
||||
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bots, nil
|
||||
}
|
||||
|
||||
if err := db.Model(&models.Bot{}).Where("user_id = ?", userID).
|
||||
Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache.Set(key, &bots, 0)
|
||||
return bots, nil
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package services
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -20,6 +19,7 @@ import (
|
|||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/auth"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/category"
|
||||
"github.com/tgdrive/teldrive/internal/database"
|
||||
"github.com/tgdrive/teldrive/internal/http_range"
|
||||
|
@ -71,7 +71,7 @@ func randInt64() (int64, error) {
|
|||
b := &buffer{Buf: buf[:]}
|
||||
return b.long()
|
||||
}
|
||||
func isUUID(str string) bool {
|
||||
func isUUId(str string) bool {
|
||||
_, err := uuid.Parse(str)
|
||||
return err == nil
|
||||
}
|
||||
|
@ -97,10 +97,10 @@ func (a *apiService) getFileFromPath(path string, userId int64) (*models.File, e
|
|||
}
|
||||
|
||||
func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStats, error) {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
var stats []api.CategoryStats
|
||||
if err := a.db.Model(&models.File{}).Select("category", "COUNT(*) as total_files", "coalesce(SUM(size),0) as total_size").
|
||||
Where(&models.File{UserID: userId, Type: "file", Status: "active"}).
|
||||
Where(&models.File{UserId: userId, Type: "file", Status: "active"}).
|
||||
Order("category ASC").Group("category").Find(&stats).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
@ -110,9 +110,9 @@ func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStat
|
|||
|
||||
func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params api.FilesCopyParams) (*api.File, error) {
|
||||
|
||||
userId, session := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||
|
||||
var res []models.File
|
||||
|
||||
|
@ -133,12 +133,9 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
}
|
||||
|
||||
err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
||||
ids := []int{}
|
||||
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
}
|
||||
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID)
|
||||
ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
|
||||
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -198,13 +195,13 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
}
|
||||
|
||||
var parentId string
|
||||
if !isUUID(req.Destination) {
|
||||
if !isUUId(req.Destination) {
|
||||
var destRes []models.File
|
||||
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
|
||||
Scan(&destRes).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
parentId = destRes[0].Id
|
||||
parentId = destRes[0].ID
|
||||
} else {
|
||||
parentId = req.Destination
|
||||
}
|
||||
|
@ -218,13 +215,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
if len(newIds) > 0 {
|
||||
dbFile.Parts = datatypes.NewJSONSlice(newIds)
|
||||
}
|
||||
dbFile.UserID = userId
|
||||
dbFile.UserId = userId
|
||||
dbFile.Status = "active"
|
||||
dbFile.ParentID = sql.NullString{
|
||||
String: parentId,
|
||||
Valid: true,
|
||||
}
|
||||
dbFile.ChannelID = &channelId
|
||||
dbFile.ParentId = utils.Ptr(parentId)
|
||||
dbFile.ChannelId = &channelId
|
||||
dbFile.Encrypted = file.Encrypted
|
||||
dbFile.Category = string(file.Category)
|
||||
if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
|
||||
|
@ -237,11 +231,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
|||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
||||
return mapper.ToFileOut(dbFile, false), nil
|
||||
return mapper.ToFileOut(dbFile), nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.File, error) {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
var (
|
||||
fileDB models.File
|
||||
|
@ -267,15 +261,9 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
|||
if err != nil {
|
||||
return nil, &apiError{err: err, code: 404}
|
||||
}
|
||||
fileDB.ParentID = sql.NullString{
|
||||
String: parent.Id,
|
||||
Valid: true,
|
||||
}
|
||||
fileDB.ParentId = utils.Ptr(parent.ID)
|
||||
} else if fileIn.ParentId.Value != "" {
|
||||
fileDB.ParentID = sql.NullString{
|
||||
String: fileIn.ParentId.Value,
|
||||
Valid: true,
|
||||
}
|
||||
fileDB.ParentId = utils.Ptr(fileIn.ParentId.Value)
|
||||
|
||||
}
|
||||
|
||||
|
@ -291,25 +279,17 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
|||
} else {
|
||||
channelId = fileIn.ChannelId.Value
|
||||
}
|
||||
fileDB.ChannelID = &channelId
|
||||
fileDB.ChannelId = &channelId
|
||||
fileDB.MimeType = fileIn.MimeType.Value
|
||||
fileDB.Category = string(category.GetCategory(fileIn.Name))
|
||||
if len(fileIn.Parts) > 0 {
|
||||
parts := []api.Part{}
|
||||
for _, part := range fileIn.Parts {
|
||||
p := api.Part{ID: part.ID}
|
||||
if part.Salt.Value != "" {
|
||||
p.Salt = part.Salt
|
||||
}
|
||||
parts = append(parts, p)
|
||||
}
|
||||
fileDB.Parts = datatypes.NewJSONSlice(parts)
|
||||
fileDB.Parts = datatypes.NewJSONSlice(mapParts(fileIn.Parts))
|
||||
}
|
||||
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
|
||||
}
|
||||
fileDB.Name = fileIn.Name
|
||||
fileDB.Type = string(fileIn.Type)
|
||||
fileDB.UserID = userId
|
||||
fileDB.UserId = userId
|
||||
fileDB.Status = "active"
|
||||
fileDB.Encrypted = fileIn.Encrypted.Value
|
||||
if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() {
|
||||
|
@ -323,11 +303,11 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
|||
}
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
return mapper.ToFileOut(fileDB, false), nil
|
||||
return mapper.ToFileOut(fileDB), nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCreate, params api.FilesCreateShareParams) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
var fileShare models.FileShare
|
||||
|
||||
|
@ -339,11 +319,11 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
|
|||
fileShare.Password = utils.Ptr(string(bytes))
|
||||
}
|
||||
|
||||
fileShare.FileID = params.ID
|
||||
fileShare.FileId = params.ID
|
||||
if req.ExpiresAt.IsSet() {
|
||||
fileShare.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
||||
}
|
||||
fileShare.UserID = userId
|
||||
fileShare.UserId = userId
|
||||
|
||||
if err := a.db.Create(&fileShare).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
|
@ -353,7 +333,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
|
|||
}
|
||||
|
||||
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
@ -362,7 +342,7 @@ func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error
|
|||
}
|
||||
|
||||
func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDeleteShareParams) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
var deletedShare models.FileShare
|
||||
|
||||
|
@ -371,14 +351,14 @@ func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDelet
|
|||
return &apiError{err: err}
|
||||
}
|
||||
if deletedShare.ID != "" {
|
||||
a.cache.Delete(fmt.Sprintf("shares:%s", deletedShare.ID))
|
||||
a.cache.Delete(cache.Key("shared", deletedShare.ID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreate, params api.FilesEditShareParams) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
var fileShareUpdate models.FileShare
|
||||
|
||||
|
@ -387,7 +367,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
|
|||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
fileShareUpdate.Password = utils.StringPointer(string(bytes))
|
||||
fileShareUpdate.Password = utils.Ptr(string(bytes))
|
||||
}
|
||||
if req.ExpiresAt.IsSet() {
|
||||
fileShareUpdate.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
||||
|
@ -403,26 +383,25 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
|
|||
|
||||
func (a *apiService) FilesGetById(ctx context.Context, params api.FilesGetByIdParams) (*api.File, error) {
|
||||
var result []fullFileDB
|
||||
notFoundResponse := &apiError{err: errors.New("file not found"), code: 404}
|
||||
if err := a.db.Model(&models.File{}).Select("*",
|
||||
"(select get_path_from_file_id as path from teldrive.get_path_from_file_id(id))").
|
||||
Where("id = ?", params.ID).Scan(&result).Error; err != nil {
|
||||
if database.IsRecordNotFoundErr(err) {
|
||||
return nil, notFoundResponse
|
||||
}
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, notFoundResponse
|
||||
return nil, &apiError{err: errors.New("file not found"), code: 404}
|
||||
}
|
||||
res := mapper.ToFileOut(result[0].File, true)
|
||||
res := mapper.ToFileOut(result[0].File)
|
||||
res.Path = api.NewOptString(result[0].Path)
|
||||
if result[0].ChannelId != nil {
|
||||
res.ChannelId = api.NewOptInt64(*result[0].ChannelId)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams) (*api.FileList, error) {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
queryBuilder := &fileQueryBuilder{db: a.db}
|
||||
|
||||
|
@ -430,7 +409,7 @@ func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams)
|
|||
}
|
||||
|
||||
func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
if err := a.db.Exec("select * from teldrive.create_directories(?, ?)", userId, req.Path).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
|
@ -439,18 +418,18 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
|
|||
}
|
||||
|
||||
func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
items := pgtype.Array[string]{
|
||||
Elements: req.Ids,
|
||||
Valid: true,
|
||||
Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}},
|
||||
}
|
||||
if !isUUID(req.Destination) {
|
||||
if !isUUId(req.Destination) {
|
||||
r, err := a.getFileFromPath(req.Destination, userId)
|
||||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
req.Destination = r.Id
|
||||
req.Destination = r.ID
|
||||
}
|
||||
if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId).
|
||||
Update("parent_id", req.Destination).Error; err != nil {
|
||||
|
@ -462,7 +441,7 @@ func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
|
|||
}
|
||||
|
||||
func (a *apiService) FilesShareByid(ctx context.Context, params api.FilesShareByidParams) (*api.FileShare, error) {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
var result []models.FileShare
|
||||
|
||||
notFoundErr := &apiError{err: errors.New("invalid share"), code: 404}
|
||||
|
@ -500,15 +479,7 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
|
|||
updateDb.Name = req.Name.Value
|
||||
}
|
||||
if len(req.Parts) > 0 {
|
||||
parts := []api.Part{}
|
||||
for _, part := range req.Parts {
|
||||
p := api.Part{ID: part.ID}
|
||||
if part.Salt.Value != "" {
|
||||
p.Salt = part.Salt
|
||||
}
|
||||
parts = append(parts, p)
|
||||
}
|
||||
updateDb.Parts = datatypes.NewJSONSlice(parts)
|
||||
updateDb.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
|
||||
}
|
||||
if req.Size.Value != 0 {
|
||||
updateDb.Size = utils.Ptr(req.Size.Value)
|
||||
|
@ -523,17 +494,18 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
|
|||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
||||
a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
|
||||
a.cache.Delete(cache.Key("files", params.ID))
|
||||
|
||||
file := models.File{}
|
||||
if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
return mapper.ToFileOut(file, false), nil
|
||||
return mapper.ToFileOut(file), nil
|
||||
}
|
||||
|
||||
func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
var file models.File
|
||||
|
||||
|
@ -545,29 +517,18 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
|
|||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
updatePayload.ChannelID = &channelId
|
||||
updatePayload.ChannelId = &channelId
|
||||
} else {
|
||||
updatePayload.ChannelID = &req.ChannelId.Value
|
||||
updatePayload.ChannelId = &req.ChannelId.Value
|
||||
}
|
||||
if len(req.Parts) > 0 {
|
||||
parts := []api.Part{}
|
||||
for _, part := range req.Parts {
|
||||
p := api.Part{ID: part.ID}
|
||||
if part.Salt.Value != "" {
|
||||
p.Salt = part.Salt
|
||||
}
|
||||
parts = append(parts, p)
|
||||
}
|
||||
updatePayload.Parts = datatypes.NewJSONSlice(parts)
|
||||
updatePayload.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
|
||||
}
|
||||
if req.Name.Value != "" {
|
||||
updatePayload.Name = req.Name.Value
|
||||
}
|
||||
if req.ParentId.Value != "" {
|
||||
updatePayload.ParentID = sql.NullString{
|
||||
String: req.ParentId.Value,
|
||||
Valid: true,
|
||||
}
|
||||
updatePayload.ParentId = utils.Ptr(req.ParentId.Value)
|
||||
}
|
||||
|
||||
updatePayload.UpdatedAt = req.UpdatedAt
|
||||
|
@ -592,28 +553,23 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
|
|||
return &apiError{err: err}
|
||||
}
|
||||
|
||||
if len(file.Parts) > 0 && file.ChannelID != nil {
|
||||
_, session := auth.GetUser(ctx)
|
||||
ids := []int{}
|
||||
keys := []string{cache.Key("files", params.ID)}
|
||||
if len(file.Parts) > 0 && file.ChannelId != nil {
|
||||
ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||
tgc.DeleteMessages(ctx, client, *file.ChannelId, ids)
|
||||
keys = append(keys, cache.Key("files", "messages", params.ID))
|
||||
for _, part := range file.Parts {
|
||||
ids = append(ids, int(part.ID))
|
||||
keys = append(keys, cache.Key("files", "location", params.ID, part.ID))
|
||||
}
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
||||
tgc.DeleteMessages(ctx, client, *file.ChannelID, ids)
|
||||
keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)}
|
||||
for _, part := range file.Parts {
|
||||
keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID))
|
||||
|
||||
}
|
||||
a.cache.Delete(keys...)
|
||||
|
||||
}
|
||||
a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
|
||||
a.cache.Delete(keys...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileID string, userId int64) {
|
||||
func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileId string, userId int64) {
|
||||
ctx := r.Context()
|
||||
var (
|
||||
session *models.Session
|
||||
|
@ -646,19 +602,17 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
session = &models.Session{UserId: userId}
|
||||
}
|
||||
|
||||
file := &api.File{}
|
||||
|
||||
key := fmt.Sprintf("files:%s", fileID)
|
||||
|
||||
err = e.api.cache.Get(key, file)
|
||||
file, err := cache.Fetch(e.api.cache, cache.Key("files", fileId), 0, func() (*models.File, error) {
|
||||
var result models.File
|
||||
if err := e.api.db.Model(&result).Where("id = ?", fileId).First(&result).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
file, err = e.api.FilesGetById(ctx, api.FilesGetByIdParams{ID: fileID})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
e.api.cache.Set(key, file, 0)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
|
@ -666,16 +620,15 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
var start, end int64
|
||||
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
contentType := defaultContentType
|
||||
|
||||
if file.Size.Value == 0 {
|
||||
w.Header().Set("Content-Type", file.MimeType.Or(defaultContentType))
|
||||
if file.MimeType != "" {
|
||||
contentType = file.MimeType
|
||||
}
|
||||
|
||||
if file.Size == nil || *file.Size == 0 {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Content-Length", "0")
|
||||
|
||||
if rangeHeader != "" {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
|
||||
http.Error(w, "Requested Range Not Satisfiable", http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": file.Name}))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
|
@ -684,11 +637,11 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
status := http.StatusOK
|
||||
if rangeHeader == "" {
|
||||
start = 0
|
||||
end = file.Size.Value - 1
|
||||
end = *file.Size - 1
|
||||
} else {
|
||||
ranges, err := http_range.Parse(rangeHeader, file.Size.Value)
|
||||
ranges, err := http_range.Parse(rangeHeader, *file.Size)
|
||||
if err == http_range.ErrNoOverlap {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", *file.Size))
|
||||
http.Error(w, http_range.ErrNoOverlap.Error(), http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
|
@ -702,20 +655,18 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
}
|
||||
start = ranges[0].Start
|
||||
end = ranges[0].End
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size.Value))
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, *file.Size))
|
||||
status = http.StatusPartialContent
|
||||
|
||||
}
|
||||
|
||||
contentLength := end - start + 1
|
||||
|
||||
mimeType := file.MimeType.Or(defaultContentType)
|
||||
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
|
||||
w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileID+strconv.FormatInt(file.Size.Value, 10))))
|
||||
w.Header().Set("Last-Modified", file.UpdatedAt.Value.UTC().Format(http.TimeFormat))
|
||||
w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileId+strconv.FormatInt(*file.Size, 10))))
|
||||
w.Header().Set("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat))
|
||||
|
||||
disposition := "inline"
|
||||
|
||||
|
@ -733,7 +684,7 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
return
|
||||
}
|
||||
|
||||
tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, file.ChannelId.Value)
|
||||
tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, *file.ChannelId)
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get bots", http.StatusInternalServerError)
|
||||
|
@ -761,9 +712,9 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
|
|||
multiThreads = 0
|
||||
|
||||
} else {
|
||||
e.api.worker.Set(tokens, file.ChannelId.Value)
|
||||
e.api.worker.Set(tokens, *file.ChannelId)
|
||||
|
||||
token, _ = e.api.worker.Next(file.ChannelId.Value)
|
||||
token, _ = e.api.worker.Next(*file.ChannelId)
|
||||
|
||||
client, err = tgc.BotClient(ctx, e.api.boltdb, &e.api.cnf.TG, token, middlewares...)
|
||||
if err != nil {
|
||||
|
@ -816,5 +767,16 @@ func (e *extendedService) SharesStream(w http.ResponseWriter, r *http.Request, s
|
|||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
e.FilesStream(w, r, fileId, share.UserID)
|
||||
e.FilesStream(w, r, fileId, share.UserId)
|
||||
}
|
||||
|
||||
func mapParts(_parts []api.Part) []api.Part {
|
||||
return utils.Map(_parts, func(part api.Part) api.Part {
|
||||
p := api.Part{ID: part.ID}
|
||||
if part.Salt.Value != "" {
|
||||
p.Salt = part.Salt
|
||||
}
|
||||
return p
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -17,8 +17,7 @@ import (
|
|||
)
|
||||
|
||||
type fileQueryBuilder struct {
|
||||
db *gorm.DB
|
||||
selectAllFields bool
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type fileResponse struct {
|
||||
|
@ -26,7 +25,7 @@ type fileResponse struct {
|
|||
Total int
|
||||
}
|
||||
|
||||
var selectedFields = []string{"id", "name", "type", "mime_type", "category", "encrypted", "size", "parent_id", "updated_at"}
|
||||
var selectedFields = []string{"id", "name", "type", "mime_type", "category", "channel_id", "encrypted", "size", "parent_id", "updated_at"}
|
||||
|
||||
func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int64) (*api.FileList, error) {
|
||||
query := afb.db.Where("user_id = ?", userId).Where("status = ?", filesQuery.Status.Value)
|
||||
|
@ -50,11 +49,7 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int
|
|||
count = res[0].Total
|
||||
}
|
||||
|
||||
files := []api.File{}
|
||||
|
||||
for _, file := range res {
|
||||
files = append(files, *mapper.ToFileOut(file.File, afb.selectAllFields))
|
||||
}
|
||||
files := utils.Map(res, func(item fileResponse) api.File { return *mapper.ToFileOut(item.File) })
|
||||
|
||||
return &api.FileList{Items: files,
|
||||
Meta: api.FileListMeta{Count: count,
|
||||
|
@ -192,15 +187,10 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File
|
|||
orderField := utils.CamelToSnake(string(filesQuery.Sort.Value))
|
||||
op := getOrderOperation(filesQuery)
|
||||
|
||||
fields := selectedFields
|
||||
if afb.selectAllFields {
|
||||
fields = append(fields, "parts", "channel_id")
|
||||
}
|
||||
|
||||
return afb.buildSubqueryCTE(query, filesQuery, userId).Clauses(exclause.NewWith("ranked_scores", afb.db.Model(&models.File{}).Select(orderField, "count(*) OVER () as total",
|
||||
fmt.Sprintf("ROW_NUMBER() OVER (ORDER BY %s %s) AS rank", orderField, strings.ToUpper(string(filesQuery.Order.Value)))).
|
||||
Where(query))).Model(&models.File{}).
|
||||
Select(fields, "(select total from ranked_scores limit 1) as total").
|
||||
Select(selectedFields, "(select total from ranked_scores limit 1) as total").
|
||||
Where(fmt.Sprintf("%s %s (SELECT %s FROM ranked_scores WHERE rank = ?)", orderField, op, orderField),
|
||||
max((filesQuery.Page.Value-1)*filesQuery.Limit.Value, 1)).
|
||||
Where(query).Order(getOrder(filesQuery)).Limit(filesQuery.Limit.Value)
|
||||
|
@ -223,13 +213,6 @@ func getOrder(filesQuery *api.FilesListParams) string {
|
|||
return fmt.Sprintf("%s %s", orderField, strings.ToUpper(string(filesQuery.Order.Value)))
|
||||
}
|
||||
|
||||
func max(x int, y int) int {
|
||||
if x > y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
func getOrderOperation(filesQuery *api.FilesListParams) string {
|
||||
if filesQuery.Page.Value == 1 {
|
||||
if filesQuery.Order.Value == api.FileQueryOrderAsc {
|
||||
|
|
|
@ -60,7 +60,7 @@ func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetById
|
|||
}
|
||||
res := &api.FileShareInfo{
|
||||
Protected: share.Password != nil,
|
||||
UserId: share.UserID,
|
||||
UserId: share.UserId,
|
||||
Type: share.Type,
|
||||
Name: share.Name,
|
||||
}
|
||||
|
@ -104,16 +104,16 @@ func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListF
|
|||
Status: api.NewOptFileQueryStatus(api.FileQueryStatusActive),
|
||||
Order: api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))),
|
||||
Sort: api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))),
|
||||
Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserID)
|
||||
Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserId)
|
||||
} else {
|
||||
var file models.File
|
||||
if err := a.db.Where("id = ?", share.FileID).First(&file).Error; err != nil {
|
||||
if err := a.db.Where("id = ?", share.FileId).First(&file).Error; err != nil {
|
||||
if database.IsRecordNotFoundErr(err) {
|
||||
return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound}
|
||||
}
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
return &api.FileList{Items: []api.File{*mapper.ToFileOut(file, false)},
|
||||
return &api.FileList{Items: []api.File{*mapper.ToFileOut(file)},
|
||||
Meta: api.FileListMeta{Count: 1, TotalPages: 1, CurrentPage: 1}}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPar
|
|||
}
|
||||
|
||||
func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
var stats []api.UploadStats
|
||||
err := a.db.Raw(`
|
||||
SELECT
|
||||
|
@ -94,7 +94,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
|
|||
return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
|
||||
}
|
||||
|
||||
userId, session := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
fileStream := req.Content.Data
|
||||
|
||||
|
@ -116,7 +116,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
|
|||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
client, err = tgc.AuthClient(ctx, &a.cnf.TG, session)
|
||||
client, err = tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -220,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
|
|||
Name: params.PartName,
|
||||
UploadId: params.ID,
|
||||
PartId: message.ID,
|
||||
ChannelID: channelId,
|
||||
ChannelId: channelId,
|
||||
Size: fileSize,
|
||||
PartNo: int(params.PartNo),
|
||||
UserId: userId,
|
||||
|
@ -244,7 +244,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
|
|||
out = api.UploadPart{
|
||||
Name: partUpload.Name,
|
||||
PartId: partUpload.PartId,
|
||||
ChannelId: partUpload.ChannelID,
|
||||
ChannelId: partUpload.ChannelId,
|
||||
PartNo: partUpload.PartNo,
|
||||
Size: partUpload.Size,
|
||||
Encrypted: partUpload.Encrypted,
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/gotd/td/tgerr"
|
||||
"github.com/tgdrive/teldrive/internal/api"
|
||||
"github.com/tgdrive/teldrive/internal/auth"
|
||||
"github.com/tgdrive/teldrive/internal/cache"
|
||||
"github.com/tgdrive/teldrive/internal/tgc"
|
||||
"github.com/tgdrive/teldrive/pkg/models"
|
||||
"github.com/tgdrive/teldrive/pkg/types"
|
||||
|
@ -27,8 +28,8 @@ import (
|
|||
)
|
||||
|
||||
func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {
|
||||
userId, session := auth.GetUser(ctx)
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
||||
userId := auth.GetUser(ctx)
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||
|
||||
if len(req.Bots) > 0 {
|
||||
channelId, err := getDefaultChannel(a.db, a.cache, userId)
|
||||
|
@ -47,11 +48,11 @@ func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {
|
|||
|
||||
func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, error) {
|
||||
|
||||
userID, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
channels := make(map[int64]*api.Channel)
|
||||
|
||||
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userID)))
|
||||
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
|
||||
|
||||
iter, err := peerStorage.Iterate(ctx)
|
||||
if err != nil {
|
||||
|
@ -67,6 +68,7 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
res := []api.Channel{}
|
||||
for _, channel := range channels {
|
||||
res = append(res, *channel)
|
||||
|
@ -79,10 +81,10 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
|
|||
}
|
||||
|
||||
func (a *apiService) UsersSyncChannels(ctx context.Context) error {
|
||||
userId, session := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
|
||||
collector := storage.CollectPeers(peerStorage)
|
||||
client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
||||
client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
@ -96,7 +98,9 @@ func (a *apiService) UsersSyncChannels(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, error) {
|
||||
userId, userSession := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
userSession := auth.GetJWTUser(ctx).TgSession
|
||||
|
||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...)
|
||||
|
||||
|
@ -150,9 +154,8 @@ func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession,
|
|||
}
|
||||
|
||||
func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProfileImageParams) (*api.UsersProfileImageOKHeaders, error) {
|
||||
_, session := auth.GetUser(ctx)
|
||||
|
||||
client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
||||
client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
|
@ -194,25 +197,25 @@ func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProf
|
|||
}
|
||||
|
||||
func (a *apiService) UsersRemoveBots(ctx context.Context) error {
|
||||
userID, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
channelId, err := getDefaultChannel(a.db, a.cache, userID)
|
||||
channelId, err := getDefaultChannel(a.db, a.cache, userId)
|
||||
if err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
||||
if err := a.db.Where("user_id = ?", userID).Where("channel_id = ?", channelId).
|
||||
if err := a.db.Where("user_id = ?", userId).Where("channel_id = ?", channelId).
|
||||
Delete(&models.Bot{}).Error; err != nil {
|
||||
return &apiError{err: err}
|
||||
}
|
||||
|
||||
a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId))
|
||||
a.cache.Delete(cache.Key("users", "bots", userId, channelId))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRemoveSessionParams) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
session := &models.Session{}
|
||||
|
||||
|
@ -236,15 +239,18 @@ func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRem
|
|||
}
|
||||
|
||||
func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
|
||||
userID, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
var (
|
||||
channelId int64
|
||||
err error
|
||||
)
|
||||
|
||||
channelId, _ = getDefaultChannel(a.db, a.cache, userID)
|
||||
channelId, err = getDefaultChannel(a.db, a.cache, userId)
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
}
|
||||
|
||||
tokens, err := getBotsToken(a.db, a.cache, userID, channelId)
|
||||
tokens, err := getBotsToken(a.db, a.cache, userId, channelId)
|
||||
|
||||
if err != nil {
|
||||
return nil, &apiError{err: err}
|
||||
|
@ -253,12 +259,12 @@ func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
|
|||
}
|
||||
|
||||
func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpdate) error {
|
||||
userId, _ := auth.GetUser(ctx)
|
||||
userId := auth.GetUser(ctx)
|
||||
|
||||
channel := &models.Channel{UserID: userId, Selected: true}
|
||||
channel := &models.Channel{UserId: userId, Selected: true}
|
||||
|
||||
if req.ChannelId.Value != 0 {
|
||||
channel.ChannelID = req.ChannelId.Value
|
||||
channel.ChannelId = req.ChannelId.Value
|
||||
}
|
||||
if req.ChannelName.Value != "" {
|
||||
channel.ChannelName = req.ChannelName.Value
|
||||
|
@ -270,11 +276,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd
|
|||
}).Create(channel).Error; err != nil {
|
||||
return &apiError{err: errors.New("failed to update channel")}
|
||||
}
|
||||
a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelID).
|
||||
a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelId).
|
||||
Where("user_id = ?", userId).Update("selected", false)
|
||||
|
||||
key := fmt.Sprintf("users:channel:%d", userId)
|
||||
a.cache.Set(key, channel.ChannelID, 0)
|
||||
a.cache.Set(cache.Key("users", "channel", userId), channel.ChannelId, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -359,12 +364,12 @@ func (a *apiService) addBots(c context.Context, client *telegram.Client, userId
|
|||
payload := []models.Bot{}
|
||||
|
||||
for _, info := range botInfoMap {
|
||||
payload = append(payload, models.Bot{UserID: userId, Token: info.Token, BotID: info.Id,
|
||||
BotUserName: info.UserName, ChannelID: channelId,
|
||||
payload = append(payload, models.Bot{UserId: userId, Token: info.Token, BotId: info.Id,
|
||||
BotUserName: info.UserName, ChannelId: channelId,
|
||||
})
|
||||
}
|
||||
|
||||
a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId))
|
||||
a.cache.Delete(cache.Key("users", "bots", userId, channelId))
|
||||
|
||||
if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
|
||||
return err
|
||||
|
|
Loading…
Add table
Reference in a new issue