Mirror of https://github.com/tgdrive/teldrive.git (synced 2025-09-11 17:04:59 +08:00)

refactor: standardize field naming and improve cache key generation

Parent: 299f4fa7ec
Commit: f0187f4052
26 changed files with 516 additions and 408 deletions
go.mod (4 changed lines)

@@ -8,13 +8,12 @@ require (
 	github.com/go-chi/chi/v5 v5.2.0
 	github.com/go-chi/cors v1.2.1
 	github.com/go-co-op/gocron v1.37.0
+	github.com/go-viper/mapstructure/v2 v2.2.1
 	github.com/golang-jwt/jwt/v5 v5.2.1
 	github.com/google/uuid v1.6.0
 	github.com/gotd/contrib v0.21.0
 	github.com/gotd/td v0.117.0
 	github.com/iyear/connectproxy v0.1.1
-	github.com/mitchellh/go-homedir v1.1.0
-	github.com/mitchellh/mapstructure v1.5.0
 	github.com/ogen-go/ogen v1.8.1
 	github.com/redis/go-redis/v9 v9.7.0
 	github.com/spf13/cobra v1.8.1
@@ -50,6 +49,7 @@ require (
 	github.com/mattn/go-colorable v0.1.14 // indirect
 	github.com/mattn/go-sqlite3 v1.14.24 // indirect
 	github.com/mfridman/interpolate v0.0.2 // indirect
+	github.com/mitchellh/mapstructure v1.5.0 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/robfig/cron/v3 v3.0.1 // indirect
 	github.com/sagikazarmark/locafero v0.7.0 // indirect
go.sum (4 changed lines)

@@ -79,6 +79,8 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
 github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
 github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
 github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
+github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
+github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
 github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
 github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
 github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
@@ -148,8 +150,6 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
 github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
 github.com/microsoft/go-mssqldb v1.8.0 h1:7cyZ/AT7ycDsEoWPIXibd+aVKFtteUNhDGf3aobP+tw=
 github.com/microsoft/go-mssqldb v1.8.0/go.mod h1:6znkekS3T2vp0waiMhen4GPU1BiAsrP+iXHcE7a7rFo=
-github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
-github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
 github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
@@ -42,10 +42,10 @@ func Decode(secret string, token string) (*types.JWTClaims, error) {
 }

-func GetUser(c context.Context) (int64, string) {
+func GetUser(c context.Context) int64 {
 	authUser, _ := c.Value(authKey).(*types.JWTClaims)
 	userId, _ := strconv.ParseInt(authUser.Subject, 10, 64)
-	return userId, authUser.TgSession
+	return userId
 }

 func GetJWTUser(c context.Context) *types.JWTClaims {
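GetUser now returns only the user id; callers that previously destructured the Telegram session take it from the JWT claims instead (the files service below does exactly this). A minimal sketch of the migration, not part of the diff, with an illustrative caller name:

// Sketch: adapting a caller to the narrower auth.GetUser signature.
// Before the change:  userId, session := auth.GetUser(ctx)
func exampleCaller(ctx context.Context) {
	userId := auth.GetUser(ctx)               // int64 parsed from the JWT subject
	session := auth.GetJWTUser(ctx).TgSession // Telegram session read from the claims instead
	_, _ = userId, session
}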
internal/cache/cache.go (vendored, 70 changed lines)

@@ -2,6 +2,10 @@ package cache

 import (
 	"context"
+	"errors"
+	"fmt"
+	"reflect"
+	"strings"
 	"sync"
 	"time"

@@ -31,6 +35,14 @@ func NewCache(ctx context.Context, conf *config.CacheConfig) Cacher {
 		cacher = NewRedisCache(ctx, redis.NewClient(&redis.Options{
 			Addr:     conf.RedisAddr,
 			Password: conf.RedisPass,
+			DialTimeout:     5 * time.Second,
+			ReadTimeout:     3 * time.Second,
+			WriteTimeout:    3 * time.Second,
+			PoolSize:        10,
+			MinIdleConns:    5,
+			MaxIdleConns:    10,
+			ConnMaxIdleTime: 5 * time.Minute,
+			ConnMaxLifetime: 1 * time.Hour,
 		}))
 	}
 	return cacher
@@ -119,3 +131,61 @@ func (r *RedisCache) Delete(keys ...string) error {
 	}
 	return r.client.Del(r.ctx, keys...).Err()
 }
+
+func Fetch[T any](cache Cacher, key string, expiration time.Duration, fn func() (T, error)) (T, error) {
+	var zero, value T
+	err := cache.Get(key, &value)
+	if err != nil {
+		if errors.Is(err, freecache.ErrNotFound) || errors.Is(err, redis.Nil) {
+			value, err = fn()
+			if err != nil {
+				return zero, err
+			}
+			cache.Set(key, &value, expiration)
+			return value, nil
+		}
+		return zero, err
+	}
+	return value, nil
+}
+
+func Key(args ...interface{}) string {
+	parts := make([]string, len(args))
+	for i, arg := range args {
+		parts[i] = formatValue(arg)
+	}
+	return strings.Join(parts, ":")
+}
+
+func formatValue(v interface{}) string {
+	if v == nil {
+		return "nil"
+	}
+
+	val := reflect.ValueOf(v)
+	switch val.Kind() {
+	case reflect.Ptr:
+		if val.IsNil() {
+			return "nil"
+		}
+		return formatValue(val.Elem().Interface())
+	case reflect.Array, reflect.Slice:
+		parts := make([]string, val.Len())
+		for i := 0; i < val.Len(); i++ {
+			parts[i] = formatValue(val.Index(i).Interface())
+		}
+		return fmt.Sprintf("[%s]", strings.Join(parts, ","))
+	case reflect.Map:
+		parts := make([]string, 0, val.Len())
+		for _, key := range val.MapKeys() {
+			k := formatValue(key.Interface())
+			v := formatValue(val.MapIndex(key).Interface())
+			parts = append(parts, fmt.Sprintf("%s=%s", k, v))
+		}
+		return fmt.Sprintf("{%s}", strings.Join(parts, ","))
+	case reflect.Struct:
+		return fmt.Sprintf("%+v", v)
+	default:
+		return fmt.Sprintf("%v", v)
+	}
+}
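The new Key and Fetch helpers are meant to be used together: Key builds a ":"-joined cache key from arbitrary values (the old fmt.Sprintf("files:location:%s:%d", ...) keys keep the same layout), and Fetch is a generic read-through wrapper that only runs the loader when the key is absent from freecache or Redis. A minimal sketch of the intended call pattern; loadFileFromDB and the key segments are illustrative, not from the diff:

// Sketch of the read-through pattern enabled by cache.Key and cache.Fetch.
func exampleLookup(c cache.Cacher, fileId string) (models.File, error) {
	key := cache.Key("files", "meta", fileId) // e.g. "files:meta:<uuid>"
	return cache.Fetch(c, key, 60*time.Minute, func() (models.File, error) {
		return loadFileFromDB(fileId) // hypothetical loader, hit only on a cache miss
	})
}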
internal/cache/cache_test.go (vendored, 152 changed lines)

@@ -5,16 +5,16 @@ import (
 	"time"

 	"github.com/stretchr/testify/assert"
-	"github.com/tgdrive/teldrive/internal/api"
+	"github.com/tgdrive/teldrive/pkg/models"
 )

 func TestCache(t *testing.T) {

-	var value = api.File{
+	var value = models.File{
 		Name: "file.jpeg",
 		Type: "file",
 	}
-	var result api.File
+	var result models.File

 	cache := NewMemoryCache(1 * 1024 * 1024)

@@ -25,3 +25,149 @@ func TestCache(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, result, value)
 }
+
+func TestKey(t *testing.T) {
+	tests := []struct {
+		name     string
+		args     []interface{}
+		expected string
+	}{
+		{
+			name:     "simple strings",
+			args:     []interface{}{"user", "123"},
+			expected: "user:123",
+		},
+		{
+			name:     "mixed types",
+			args:     []interface{}{"cache", 123, true},
+			expected: "cache:123:true",
+		},
+		{
+			name:     "with nil",
+			args:     []interface{}{"key", nil, "value"},
+			expected: "key:nil:value",
+		},
+		{
+			name:     "empty args",
+			args:     []interface{}{},
+			expected: "",
+		},
+		{
+			name:     "single arg",
+			args:     []interface{}{"solo"},
+			expected: "solo",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := Key(tt.args...)
+			if result != tt.expected {
+				t.Errorf("Key() = %v, want %v", result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestFormatValue(t *testing.T) {
+	type testStruct struct {
+		Name string
+		Age  int
+	}
+
+	tests := []struct {
+		name     string
+		input    interface{}
+		expected string
+	}{
+		{
+			name:     "nil value",
+			input:    nil,
+			expected: "nil",
+		},
+		{
+			name:     "string",
+			input:    "test",
+			expected: "test",
+		},
+		{
+			name:     "integer",
+			input:    123,
+			expected: "123",
+		},
+		{
+			name:     "boolean",
+			input:    true,
+			expected: "true",
+		},
+		{
+			name:     "slice of strings",
+			input:    []string{"a", "b", "c"},
+			expected: "[a,b,c]",
+		},
+		{
+			name:     "slice of ints",
+			input:    []int{1, 2, 3},
+			expected: "[1,2,3]",
+		},
+		{
+			name:     "empty slice",
+			input:    []string{},
+			expected: "[]",
+		},
+		{
+			name:     "map string to string",
+			input:    map[string]string{"a": "1", "b": "2"},
+			expected: "{a=1,b=2}",
+		},
+		{
+			name:     "empty map",
+			input:    map[string]string{},
+			expected: "{}",
+		},
+		{
+			name:     "struct",
+			input:    testStruct{Name: "John", Age: 30},
+			expected: "{Name:John Age:30}",
+		},
+		{
+			name:     "pointer to string",
+			input:    func() interface{} { s := "test"; return &s }(),
+			expected: "test",
+		},
+		{
+			name:     "nil pointer",
+			input:    func() interface{} { var s *string; return s }(),
+			expected: "nil",
+		},
+		{
+			name:     "nested slice",
+			input:    [][]int{{1, 2}, {3, 4}},
+			expected: "[[1,2],[3,4]]",
+		},
+		{
+			name: "complex mixed structure",
+			input: struct {
+				ID    int
+				Tags  []string
+				Meta  map[string]interface{}
+				Valid bool
+			}{
+				ID:    1,
+				Tags:  []string{"a", "b"},
+				Meta:  map[string]interface{}{"count": 42},
+				Valid: true,
+			},
+			expected: "{ID:1 Tags:[a b] Meta:map[count:42] Valid:true}",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := formatValue(tt.input)
+			if result != tt.expected {
+				t.Errorf("formatValue() = %v, want %v", result, tt.expected)
+			}
+		})
+	}
+}
@@ -2,13 +2,13 @@ package config

 import (
 	"fmt"
+	"os"
 	"path/filepath"
 	"reflect"
 	"strings"
 	"time"

-	"github.com/mitchellh/go-homedir"
-	"github.com/mitchellh/mapstructure"
+	"github.com/go-viper/mapstructure/v2"
 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"
 	"github.com/spf13/viper"
@@ -144,7 +144,7 @@ func (cl *ConfigLoader) InitializeConfig(cmd *cobra.Command) error {
 	if cfgFile != "" {
 		cl.v.SetConfigFile(cfgFile)
 	} else {
-		home, err := homedir.Dir()
+		home, err := os.UserHomeDir()
 		if err != nil {
 			return fmt.Errorf("error getting home directory: %v", err)
 		}
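The mitchellh/mapstructure import is replaced by the maintained go-viper fork; the package is still imported as mapstructure, so the decoder setup elsewhere in this loader should not need to change. A small sketch of the drop-in usage, assuming the v2 fork keeps the v1 decoder API (the config struct and hook choice are illustrative):

// Sketch: same "mapstructure" package name, sourced from the go-viper fork.
import "github.com/go-viper/mapstructure/v2"

type exampleConf struct {
	Timeout time.Duration `mapstructure:"timeout"`
}

func decodeExample(raw map[string]any) (exampleConf, error) {
	var out exampleConf
	dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		DecodeHook: mapstructure.StringToTimeDurationHookFunc(), // parse "30s" into time.Duration
		Result:     &out,
	})
	if err != nil {
		return out, err
	}
	return out, dec.Decode(raw)
}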
@@ -2,14 +2,13 @@ package reader

 import (
 	"context"
-	"fmt"
 	"io"

 	"github.com/gotd/td/tg"
-	"github.com/tgdrive/teldrive/internal/api"
 	"github.com/tgdrive/teldrive/internal/cache"
 	"github.com/tgdrive/teldrive/internal/config"
 	"github.com/tgdrive/teldrive/internal/crypt"
+	"github.com/tgdrive/teldrive/pkg/models"
 	"github.com/tgdrive/teldrive/pkg/types"
 )

@@ -20,7 +19,7 @@ type Range struct {

 type LinearReader struct {
 	ctx    context.Context
-	file   *api.File
+	file   *models.File
 	parts  []types.Part
 	ranges []Range
 	pos    int
@@ -52,7 +51,7 @@ func calculatePartByteRanges(start, end, partSize int64) []Range {
 func NewLinearReader(ctx context.Context,
 	client *tg.Client,
 	cache cache.Cacher,
-	file *api.File,
+	file *models.File,
 	parts []types.Part,
 	start,
 	end int64,
@@ -61,7 +60,7 @@ func NewLinearReader(ctx context.Context,
 ) (io.ReadCloser, error) {

 	size := parts[0].Size
-	if file.Encrypted.Value {
+	if file.Encrypted {
 		size = parts[0].DecryptedSize
 	}
 	r := &LinearReader{
@@ -129,22 +128,22 @@ func (r *LinearReader) moveToNextPart() error {

 func (r *LinearReader) getPartReader() (io.ReadCloser, error) {
 	currentRange := r.ranges[r.pos]
-	partID := r.parts[currentRange.PartNo].ID
+	partId := r.parts[currentRange.PartNo].ID

 	chunkSrc := &chunkSource{
-		channelID:   r.file.ChannelId.Value,
-		partID:      partID,
+		channelId:   *r.file.ChannelId,
+		partId:      partId,
 		client:      r.client,
 		concurrency: r.concurrency,
 		cache:       r.cache,
-		key:         fmt.Sprintf("files:location:%s:%d", r.file.ID.Value, partID),
+		key:         cache.Key("files", "location", r.file.ID, partId),
 	}

 	var (
 		reader io.ReadCloser
 		err    error
 	)
-	if r.file.Encrypted.Value {
+	if r.file.Encrypted {
 		salt := r.parts[r.ranges[r.pos].PartNo].Salt
 		cipher, _ := crypt.NewCipher(r.config.Uploads.EncryptionKey, salt)
 		reader, err = cipher.DecryptDataSeek(r.ctx,
@@ -25,8 +25,8 @@ type ChunkSource interface {
 }

 type chunkSource struct {
-	channelID   int64
-	partID      int64
+	channelId   int64
+	partId      int64
 	concurrency int
 	client      *tg.Client
 	key         string
@@ -46,7 +46,7 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b
 	err = c.cache.Get(c.key, location)

 	if err != nil {
-		location, err = tgc.GetLocation(ctx, c.client, c.channelID, c.partID)
+		location, err = tgc.GetLocation(ctx, c.client, c.channelId, c.partId)
 		if err != nil {
 			return nil, err
 		}
@@ -5,14 +5,13 @@ import (
 	"path/filepath"
 	"time"

-	"github.com/mitchellh/go-homedir"
 	"github.com/tgdrive/teldrive/internal/utils"
 	"go.etcd.io/bbolt"
 )

 func NewBoltDB(sessionFile string) (*bbolt.DB, error) {
 	if sessionFile == "" {
-		dir, err := homedir.Dir()
+		dir, err := os.UserHomeDir()
 		if err != nil {
 			dir = utils.ExecutableDir()
 		} else {
@@ -12,13 +12,14 @@ import (
 	"github.com/gotd/td/telegram"
 	"github.com/gotd/td/tg"
 	"github.com/tgdrive/teldrive/internal/config"
+	"github.com/tgdrive/teldrive/internal/utils"
 	"github.com/tgdrive/teldrive/pkg/types"
 	"go.etcd.io/bbolt"
 	"golang.org/x/sync/errgroup"
 )

 var (
-	ErrInValidChannelID       = errors.New("invalid channel id")
+	ErrInValidChannelId       = errors.New("invalid channel id")
 	ErrInvalidChannelMessages = errors.New("invalid channel messages")
 )

@@ -33,7 +34,7 @@ func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*t
 	}

 	if len(channels.GetChats()) == 0 {
-		return nil, ErrInValidChannelID
+		return nil, ErrInValidChannelId
 	}
 	return channels.GetChats()[0].(*tg.Channel).AsInput(), nil
 }
@@ -71,15 +72,11 @@ func DeleteMessages(ctx context.Context, client *telegram.Client, channelId int6

 func getTGMessagesBatch(ctx context.Context, client *tg.Client, channel *tg.InputChannel, ids []int) (tg.MessagesMessagesClass, error) {

-	msgIds := []tg.InputMessageClass{}
-
-	for _, id := range ids {
-		msgIds = append(msgIds, &tg.InputMessageID{ID: id})
-	}
-
 	messageRequest := tg.ChannelsGetMessagesRequest{
 		Channel: channel,
-		ID:      msgIds,
+		ID: utils.Map(ids, func(id int) tg.InputMessageClass {
+			return &tg.InputMessageID{ID: id}
+		}),
 	}

 	res, err := client.ChannelsGetMessages(ctx, &messageRequest)
@@ -17,21 +17,21 @@ func NewBotWorker() *BotWorker {
 	}
 }

-func (w *BotWorker) Set(bots []string, channelID int64) {
+func (w *BotWorker) Set(bots []string, channelId int64) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
-	if _, ok := w.bots[channelID]; ok {
+	if _, ok := w.bots[channelId]; ok {
 		return
 	}
-	w.bots[channelID] = bots
-	w.currIdx[channelID] = 0
+	w.bots[channelId] = bots
+	w.currIdx[channelId] = 0
 }

-func (w *BotWorker) Next(channelID int64) (string, int) {
+func (w *BotWorker) Next(channelId int64) (string, int) {
 	w.mu.RLock()
 	defer w.mu.RUnlock()
-	bots := w.bots[channelID]
-	index := w.currIdx[channelID]
-	w.currIdx[channelID] = (index + 1) % len(bots)
+	bots := w.bots[channelId]
+	index := w.currIdx[channelId]
+	w.currIdx[channelId] = (index + 1) % len(bots)
 	return bots[index], index
 }
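Only the parameter name changes here; the per-channel round-robin behaviour is the same. A short usage sketch, with made-up channel id and tokens:

// Sketch: round-robin bot selection after the channelId rename.
func exampleRoundRobin() {
	w := NewBotWorker()
	w.Set([]string{"tokenA", "tokenB"}, 123456) // hypothetical tokens and channel id
	t1, i1 := w.Next(123456)                    // "tokenA", 0
	t2, i2 := w.Next(123456)                    // "tokenB", 1
	_, _, _, _ = t1, i1, t2, i2
}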
@@ -5,73 +5,18 @@ import (
 	"path/filepath"
 	"regexp"
 	"strings"
-	"time"
-
-	"reflect"
-
-	"unicode"
 )

-func CamelToPascalCase(input string) string {
-	var result strings.Builder
-	upperNext := true
-
-	for _, char := range input {
-		if unicode.IsLetter(char) || unicode.IsDigit(char) {
-			if upperNext {
-				result.WriteRune(unicode.ToUpper(char))
-				upperNext = false
-			} else {
-				result.WriteRune(char)
-			}
-		} else {
-			upperNext = true
-		}
-	}
-
-	return result.String()
-}
-
 func CamelToSnake(input string) string {
 	re := regexp.MustCompile("([a-z0-9])([A-Z])")
 	snake := re.ReplaceAllString(input, "${1}_${2}")
 	return strings.ToLower(snake)
 }

-func GetField(v interface{}, field string) string {
-	r := reflect.ValueOf(v)
-	f := reflect.Indirect(r).FieldByName(field)
-	fieldValue := f.Interface()
-
-	switch v := fieldValue.(type) {
-	case string:
-		return v
-	case time.Time:
-		return v.Format(time.RFC3339)
-	default:
-		return ""
-	}
-}
-
 func Ptr[T any](t T) *T {
 	return &t
 }

-func BoolPointer(b bool) *bool {
-	return &b
-}
-
-func IntPointer(b int) *int {
-	return &b
-}
-
-func Int64Pointer(b int64) *int64 {
-	return &b
-}
-
-func StringPointer(b string) *string {
-	return &b
-}
-
 func PathExists(path string) (bool, error) {
 	_, err := os.Stat(path)
 	if err == nil {
@@ -84,8 +29,34 @@ func PathExists(path string) (bool, error) {
 }

 func ExecutableDir() string {

 	path, _ := os.Executable()

 	return filepath.Dir(path)
 }
+
+func Filter[T any](slice []T, predicate func(T) bool) []T {
+	var result []T
+	for _, v := range slice {
+		if predicate(v) {
+			result = append(result, v)
+		}
+	}
+	return result
+}
+
+func Map[T any, U any](slice []T, mapper func(T) U) []U {
+	var result []U
+	for _, v := range slice {
+		result = append(result, mapper(v))
+	}
+	return result
+}
+
+func Find[T any](slice []T, predicate func(T) bool) (T, bool) {
+	for _, v := range slice {
+		if predicate(v) {
+			return v, true
+		}
+	}
+	var zero T
+	return zero, false
+}
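The removed type-specific pointer helpers are all covered by the existing generic Ptr, and the new Filter, Map and Find generics replace hand-written loops elsewhere in this commit (see the tgc, mapper and services diffs). A compact usage sketch of the new helpers:

// Sketch of the new slice helpers (as a caller of package utils would use them).
func exampleHelpers() {
	ids := []int{1, 2, 3, 4}
	even := Filter(ids, func(n int) bool { return n%2 == 0 }) // [2 4]
	doubled := Map(ids, func(n int) int { return n * 2 })     // [2 4 6 8]
	first, ok := Find(ids, func(n int) bool { return n > 2 }) // 3, true
	size := Ptr(int64(1024))                                  // replaces the removed Int64Pointer
	_, _, _, _, _ = even, doubled, first, ok, size
}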
@@ -2,47 +2,41 @@ package mapper

 import (
 	"github.com/tgdrive/teldrive/internal/api"
+	"github.com/tgdrive/teldrive/internal/utils"
 	"github.com/tgdrive/teldrive/pkg/models"
 )

-func ToFileOut(file models.File, extended bool) *api.File {
+func ToFileOut(file models.File) *api.File {
 	res := &api.File{
-		ID:        api.NewOptString(file.Id),
+		ID:        api.NewOptString(file.ID),
 		Name:      file.Name,
 		Type:      api.FileType(file.Type),
 		MimeType:  api.NewOptString(file.MimeType),
 		Encrypted: api.NewOptBool(file.Encrypted),
-		ParentId:  api.NewOptString(file.ParentID.String),
 		UpdatedAt: api.NewOptDateTime(file.UpdatedAt),
 	}
+	if file.ParentId != nil {
+		res.ParentId = api.NewOptString(*file.ParentId)
+	}
 	if file.Size != nil {
 		res.Size = api.NewOptInt64(*file.Size)
 	}
 	if file.Category != "" {
 		res.Category = api.NewOptFileCategory(api.FileCategory(file.Category))
 	}
-	if extended {
-		res.Parts = file.Parts
-		if file.ChannelID != nil {
-			res.ChannelId = api.NewOptInt64(*file.ChannelID)
-		}
-	}
 	return res
 }

 func ToUploadOut(parts []models.Upload) []api.UploadPart {
-	res := []api.UploadPart{}
-	for _, part := range parts {
-		res = append(res, api.UploadPart{
+	return utils.Map(parts, func(part models.Upload) api.UploadPart {
+		return api.UploadPart{
 			Name:      part.Name,
 			PartId:    part.PartId,
-			ChannelId: part.ChannelID,
+			ChannelId: part.ChannelId,
 			PartNo:    part.PartNo,
 			Size:      part.Size,
 			Encrypted: part.Encrypted,
 			Salt:      api.NewOptString(part.Salt),
+		}
 	})
-	}
-	return res
 }
@@ -2,8 +2,8 @@ package models

 type Bot struct {
 	Token       string `gorm:"type:text;primaryKey"`
-	UserID      int64  `gorm:"type:bigint"`
-	BotID       int64  `gorm:"type:bigint"`
+	UserId      int64  `gorm:"type:bigint"`
+	BotId       int64  `gorm:"type:bigint"`
 	BotUserName string `gorm:"type:text"`
-	ChannelID   int64  `gorm:"type:bigint"`
+	ChannelId   int64  `gorm:"type:bigint"`
 }
@@ -1,8 +1,8 @@
 package models

 type Channel struct {
-	ChannelID   int64  `gorm:"type:bigint;primaryKey"`
+	ChannelId   int64  `gorm:"type:bigint;primaryKey"`
 	ChannelName string `gorm:"type:text"`
-	UserID      int64  `gorm:"type:bigint;"`
+	UserId      int64  `gorm:"type:bigint;"`
 	Selected    bool   `gorm:"type:boolean;"`
 }
@@ -1,7 +1,6 @@
 package models

 import (
-	"database/sql"
 	"time"

 	"github.com/tgdrive/teldrive/internal/api"
@@ -9,18 +8,18 @@ import (
 )

 type File struct {
-	Id        string                        `gorm:"type:uuid;primaryKey;default:uuid7()"`
+	ID        string                        `gorm:"type:uuid;primaryKey;default:uuid7()"`
 	Name      string                        `gorm:"type:text;not null"`
 	Type      string                        `gorm:"type:text;not null"`
 	MimeType  string                        `gorm:"type:text;not null"`
 	Size      *int64                        `gorm:"type:bigint"`
 	Category  string                        `gorm:"type:text"`
 	Encrypted bool                          `gorm:"default:false"`
-	UserID    int64                         `gorm:"type:bigint;not null"`
+	UserId    int64                         `gorm:"type:bigint;not null"`
 	Status    string                        `gorm:"type:text"`
-	ParentID  sql.NullString                `gorm:"type:uuid;index"`
+	ParentId  *string                       `gorm:"type:uuid;index"`
 	Parts     datatypes.JSONSlice[api.Part] `gorm:"type:jsonb"`
-	ChannelID *int64                        `gorm:"type:bigint"`
+	ChannelId *int64                        `gorm:"type:bigint"`
 	CreatedAt time.Time                     `gorm:"default:timezone('utc'::text, now())"`
 	UpdatedAt time.Time                     `gorm:"autoUpdateTime:false"`
 }
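ParentID moves from sql.NullString to a plain *string (and Id becomes ID). With GORM a nil pointer is written and scanned as SQL NULL just like an invalid NullString was, which is what lets the mapper and the file service drop the sql.NullString construction in favour of utils.Ptr. A minimal sketch of the two states; the UUID value is made up:

// Sketch: nil pointer vs. set pointer after the ParentId change.
func exampleParent() {
	root := models.File{Name: "root", Type: "folder"} // ParentId == nil -> parent_id IS NULL
	child := models.File{
		Name:     "docs",
		Type:     "folder",
		ParentId: utils.Ptr("0193e7a4-1111-7abc-8def-000000000000"), // hypothetical UUID
	}
	_, _ = root, child
}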
@@ -6,10 +6,10 @@ import (

 type FileShare struct {
 	ID        string     `gorm:"type:uuid;default:uuid_generate_v4();primary_key"`
-	FileID    string     `gorm:"type:uuid;not null"`
+	FileId    string     `gorm:"type:uuid;not null"`
 	Password  *string    `gorm:"type:text"`
 	ExpiresAt *time.Time `gorm:"type:timestamp"`
 	CreatedAt time.Time  `gorm:"type:timestamp;not null;default:current_timestamp"`
 	UpdatedAt time.Time  `gorm:"type:timestamp;not null;default:current_timestamp"`
-	UserID    int64      `gorm:"type:bigint;not null"`
+	UserId    int64      `gorm:"type:bigint;not null"`
 }
@@ -12,7 +12,7 @@ type Upload struct {
 	PartId    int       `gorm:"type:integer"`
 	Encrypted bool      `gorm:"default:false"`
 	Salt      string    `gorm:"type:text"`
-	ChannelID int64     `gorm:"type:bigint"`
+	ChannelId int64     `gorm:"type:bigint"`
 	Size      int64     `gorm:"type:bigint"`
 	CreatedAt time.Time `gorm:"default:timezone('utc'::text, now())"`
 }
@@ -25,6 +25,7 @@ import (
 	"github.com/gotd/td/tgerr"
 	"github.com/tgdrive/teldrive/internal/api"
 	"github.com/tgdrive/teldrive/internal/auth"
+	"github.com/tgdrive/teldrive/internal/cache"
 	"github.com/tgdrive/teldrive/internal/logging"
 	"github.com/tgdrive/teldrive/internal/tgc"
 	"github.com/tgdrive/teldrive/pkg/models"
@@ -80,7 +81,7 @@ func (a *apiService) AuthLogin(ctx context.Context, session *api.SessionCreate)
 		Name:     "root",
 		Type:     "folder",
 		MimeType: "drive/folder",
-		UserID:   session.UserId,
+		UserId:   session.UserId,
 		Status:   "active",
 		Parts:    nil,
 	}
@@ -129,7 +130,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
 		return err
 	})
 	a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{})
-	a.cache.Delete(fmt.Sprintf("sessions:%s", authUser.Hash))
+	a.cache.Delete(cache.Key("sessions", authUser.Hash))
 	return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil
 }

@@ -360,7 +361,7 @@ func pack32BinaryIP4(ip4Address string) []byte {
 	return buf.Bytes()
 }

-func generateTgSession(dcID int, authKey []byte, port int) string {
+func generateTgSession(dcId int, authKey []byte, port int) string {

 	dcMaps := map[int]string{
 		1: "149.154.175.53",
@@ -370,8 +371,8 @@ func generateTgSession(dcID int, authKey []byte, port int) string {
 		5: "91.108.56.130",
 	}

-	dcIDByte := byte(dcID)
-	serverAddressBytes := pack32BinaryIP4(dcMaps[dcID])
+	dcIDByte := byte(dcId)
+	serverAddressBytes := pack32BinaryIP4(dcMaps[dcId])
 	portByte := make([]byte, 2)
 	binary.BigEndian.PutUint16(portByte, uint16(port))

@@ -2,106 +2,88 @@ package services

 import (
 	"context"
+	"errors"
 	"fmt"
 	"time"

-	"github.com/go-faster/errors"
 	"github.com/gotd/td/telegram"
 	"github.com/gotd/td/tg"
 	"github.com/tgdrive/teldrive/internal/api"
 	"github.com/tgdrive/teldrive/internal/cache"
 	"github.com/tgdrive/teldrive/internal/crypt"
+	"github.com/tgdrive/teldrive/internal/logging"
 	"github.com/tgdrive/teldrive/internal/tgc"
+	"github.com/tgdrive/teldrive/internal/utils"
 	"github.com/tgdrive/teldrive/pkg/models"
 	"github.com/tgdrive/teldrive/pkg/types"
+	"go.uber.org/zap"
 	"gorm.io/gorm"
 )

-func getParts(ctx context.Context, client *telegram.Client, cache cache.Cacher, file *api.File) ([]types.Part, error) {
-
-	parts := []types.Part{}
-
-	key := fmt.Sprintf("files:messages:%s", file.ID.Value)
-
-	err := cache.Get(key, &parts)
-
-	if err == nil {
-		return parts, nil
-	}
-
-	ids := []int{}
-	for _, part := range file.Parts {
-		ids = append(ids, int(part.ID))
-	}
-	messages, err := tgc.GetMessages(ctx, client.API(), ids, file.ChannelId.Value)
-
-	if err != nil {
-		return nil, err
-	}
-
-	for i, message := range messages {
-		item := message.(*tg.Message)
-		media := item.Media.(*tg.MessageMediaDocument)
-		document := media.Document.(*tg.Document)
-		part := types.Part{
-			ID:   int64(file.Parts[i].ID),
-			Size: document.Size,
-			Salt: file.Parts[i].Salt.Value,
-		}
-		if file.Encrypted.Value {
-			part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
-		}
-		parts = append(parts, part)
-	}
-	cache.Set(key, &parts, 60*time.Minute)
-	return parts, nil
+func getParts(ctx context.Context, client *telegram.Client, c cache.Cacher, file *models.File) ([]types.Part, error) {
+	return cache.Fetch(c, cache.Key("files", "messages", file.ID), 60*time.Minute, func() ([]types.Part, error) {
+		messages, err := tgc.GetMessages(ctx, client.API(), utils.Map(file.Parts, func(part api.Part) int {
+			return part.ID
+		}), *file.ChannelId)
+
+		if err != nil {
+			return nil, err
+		}
+		parts := []types.Part{}
+		for i, message := range messages {
+			switch item := message.(type) {
+			case *tg.Message:
+				media, ok := item.Media.(*tg.MessageMediaDocument)
+				if !ok {
+					continue
+				}
+				document, ok := media.Document.(*tg.Document)
+				if !ok {
+					continue
+				}
+				part := types.Part{
+					ID:   int64(file.Parts[i].ID),
+					Size: document.Size,
+					Salt: file.Parts[i].Salt.Value,
+				}
+				if file.Encrypted {
+					part.DecryptedSize, _ = crypt.DecryptedSize(document.Size)
+				}
+				parts = append(parts, part)
+			}
+		}
+		if len(parts) != len(file.Parts) {
+			msg := "file parts mismatch"
+			logging.FromContext(ctx).Error(msg, zap.String("name", file.Name),
+				zap.Int("expected", len(file.Parts)), zap.Int("actual", len(parts)))
+			return nil, errors.New(msg)
+		}
+		return parts, nil
+	})
 }

-func getDefaultChannel(db *gorm.DB, cache cache.Cacher, userID int64) (int64, error) {
-
-	var channelId int64
-	key := fmt.Sprintf("users:channel:%d", userID)
-
-	err := cache.Get(key, &channelId)
-
-	if err == nil {
-		return channelId, nil
-	}
-
-	var channelIds []int64
-	db.Model(&models.Channel{}).Where("user_id = ?", userID).Where("selected = ?", true).
-		Pluck("channel_id", &channelIds)
-
-	if len(channelIds) == 1 {
-		channelId = channelIds[0]
-		cache.Set(key, channelId, 0)
-	}
-
-	if channelId == 0 {
-		return channelId, errors.New("default channel not set")
-	}
-
-	return channelId, nil
+func getDefaultChannel(db *gorm.DB, c cache.Cacher, userId int64) (int64, error) {
+	return cache.Fetch(c, cache.Key("users", "channel", userId), 0, func() (int64, error) {
+		var channelIds []int64
+		if err := db.Model(&models.Channel{}).Where("user_id = ?", userId).Where("selected = ?", true).
+			Pluck("channel_id", &channelIds).Error; err != nil {
+			return 0, err
+		}
+		if len(channelIds) == 0 {
+			return 0, fmt.Errorf("no default channel found for user %d", userId)
+		}
+		return channelIds[0], nil
+	})
 }

-func getBotsToken(db *gorm.DB, cache cache.Cacher, userID, channelId int64) ([]string, error) {
-	var bots []string
-
-	key := fmt.Sprintf("users:bots:%d:%d", userID, channelId)
-
-	err := cache.Get(key, &bots)
-
-	if err == nil {
-		return bots, nil
-	}
-
-	if err := db.Model(&models.Bot{}).Where("user_id = ?", userID).
-		Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
-		return nil, err
-	}
-
-	cache.Set(key, &bots, 0)
-	return bots, nil
-
+func getBotsToken(db *gorm.DB, c cache.Cacher, userId, channelId int64) ([]string, error) {
+	return cache.Fetch(c, cache.Key("users", "bots", userId, channelId), 0, func() ([]string, error) {
+		var bots []string
+		if err := db.Model(&models.Bot{}).Where("user_id = ?", userId).
+			Where("channel_id = ?", channelId).Pluck("token", &bots).Error; err != nil {
+			return nil, err
+		}
+		return bots, nil
+	})
 }
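getDefaultChannel and getBotsToken now cache with a zero expiration, so their entries live until explicitly removed; when a user switches channels or bots, the matching keys have to be invalidated. A hedged sketch of that invalidation — the call site and timing are assumptions, only the key layout comes from the diff:

// Sketch: removing the zero-expiry entries written by the Fetch-based helpers.
func invalidateUserChannel(c cache.Cacher, userId, channelId int64) error {
	return c.Delete(
		cache.Key("users", "channel", userId),         // default-channel entry
		cache.Key("users", "bots", userId, channelId), // bot-token entry for that channel
	)
}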
@ -3,7 +3,6 @@ package services
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -20,6 +19,7 @@ import (
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/tgdrive/teldrive/internal/api"
|
"github.com/tgdrive/teldrive/internal/api"
|
||||||
"github.com/tgdrive/teldrive/internal/auth"
|
"github.com/tgdrive/teldrive/internal/auth"
|
||||||
|
"github.com/tgdrive/teldrive/internal/cache"
|
||||||
"github.com/tgdrive/teldrive/internal/category"
|
"github.com/tgdrive/teldrive/internal/category"
|
||||||
"github.com/tgdrive/teldrive/internal/database"
|
"github.com/tgdrive/teldrive/internal/database"
|
||||||
"github.com/tgdrive/teldrive/internal/http_range"
|
"github.com/tgdrive/teldrive/internal/http_range"
|
||||||
|
@ -71,7 +71,7 @@ func randInt64() (int64, error) {
|
||||||
b := &buffer{Buf: buf[:]}
|
b := &buffer{Buf: buf[:]}
|
||||||
return b.long()
|
return b.long()
|
||||||
}
|
}
|
||||||
func isUUID(str string) bool {
|
func isUUId(str string) bool {
|
||||||
_, err := uuid.Parse(str)
|
_, err := uuid.Parse(str)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
@ -97,10 +97,10 @@ func (a *apiService) getFileFromPath(path string, userId int64) (*models.File, e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStats, error) {
|
func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStats, error) {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
var stats []api.CategoryStats
|
var stats []api.CategoryStats
|
||||||
if err := a.db.Model(&models.File{}).Select("category", "COUNT(*) as total_files", "coalesce(SUM(size),0) as total_size").
|
if err := a.db.Model(&models.File{}).Select("category", "COUNT(*) as total_files", "coalesce(SUM(size),0) as total_size").
|
||||||
Where(&models.File{UserID: userId, Type: "file", Status: "active"}).
|
Where(&models.File{UserId: userId, Type: "file", Status: "active"}).
|
||||||
Order("category ASC").Group("category").Find(&stats).Error; err != nil {
|
Order("category ASC").Group("category").Find(&stats).Error; err != nil {
|
||||||
return nil, &apiError{err: err}
|
return nil, &apiError{err: err}
|
||||||
}
|
}
|
||||||
|
@ -110,9 +110,9 @@ func (a *apiService) FilesCategoryStats(ctx context.Context) ([]api.CategoryStat
|
||||||
|
|
||||||
func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params api.FilesCopyParams) (*api.File, error) {
|
func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params api.FilesCopyParams) (*api.File, error) {
|
||||||
|
|
||||||
userId, session := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
|
|
||||||
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
|
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
|
||||||
|
|
||||||
var res []models.File
|
var res []models.File
|
||||||
|
|
||||||
|
@ -133,12 +133,9 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
err = tgc.RunWithAuth(ctx, client, "", func(ctx context.Context) error {
|
||||||
ids := []int{}
|
|
||||||
|
|
||||||
for _, part := range file.Parts {
|
ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
|
||||||
ids = append(ids, int(part.ID))
|
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelId)
|
||||||
}
|
|
||||||
messages, err := tgc.GetMessages(ctx, client.API(), ids, *file.ChannelID)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -198,13 +195,13 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
||||||
}
|
}
|
||||||
|
|
||||||
var parentId string
|
var parentId string
|
||||||
if !isUUID(req.Destination) {
|
if !isUUId(req.Destination) {
|
||||||
var destRes []models.File
|
var destRes []models.File
|
||||||
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
|
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
|
||||||
Scan(&destRes).Error; err != nil {
|
Scan(&destRes).Error; err != nil {
|
||||||
return nil, &apiError{err: err}
|
return nil, &apiError{err: err}
|
||||||
}
|
}
|
||||||
parentId = destRes[0].Id
|
parentId = destRes[0].ID
|
||||||
} else {
|
} else {
|
||||||
parentId = req.Destination
|
parentId = req.Destination
|
||||||
}
|
}
|
||||||
|
@ -218,13 +215,10 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
||||||
if len(newIds) > 0 {
|
if len(newIds) > 0 {
|
||||||
dbFile.Parts = datatypes.NewJSONSlice(newIds)
|
dbFile.Parts = datatypes.NewJSONSlice(newIds)
|
||||||
}
|
}
|
||||||
dbFile.UserID = userId
|
dbFile.UserId = userId
|
||||||
dbFile.Status = "active"
|
dbFile.Status = "active"
|
||||||
dbFile.ParentID = sql.NullString{
|
dbFile.ParentId = utils.Ptr(parentId)
|
||||||
String: parentId,
|
dbFile.ChannelId = &channelId
|
||||||
Valid: true,
|
|
||||||
}
|
|
||||||
dbFile.ChannelID = &channelId
|
|
||||||
dbFile.Encrypted = file.Encrypted
|
dbFile.Encrypted = file.Encrypted
|
||||||
dbFile.Category = string(file.Category)
|
dbFile.Category = string(file.Category)
|
||||||
if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
|
if req.UpdatedAt.IsSet() && !req.UpdatedAt.Value.IsZero() {
|
||||||
|
@ -237,11 +231,11 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
|
||||||
return nil, &apiError{err: err}
|
return nil, &apiError{err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mapper.ToFileOut(dbFile, false), nil
|
return mapper.ToFileOut(dbFile), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.File, error) {
|
func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.File, error) {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
fileDB models.File
|
fileDB models.File
|
||||||
|
@ -267,15 +261,9 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &apiError{err: err, code: 404}
|
return nil, &apiError{err: err, code: 404}
|
||||||
}
|
}
|
||||||
fileDB.ParentID = sql.NullString{
|
fileDB.ParentId = utils.Ptr(parent.ID)
|
||||||
String: parent.Id,
|
|
||||||
Valid: true,
|
|
||||||
}
|
|
||||||
} else if fileIn.ParentId.Value != "" {
|
} else if fileIn.ParentId.Value != "" {
|
||||||
fileDB.ParentID = sql.NullString{
|
fileDB.ParentId = utils.Ptr(fileIn.ParentId.Value)
|
||||||
String: fileIn.ParentId.Value,
|
|
||||||
Valid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,25 +279,17 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
||||||
} else {
|
} else {
|
||||||
channelId = fileIn.ChannelId.Value
|
channelId = fileIn.ChannelId.Value
|
||||||
}
|
}
|
||||||
fileDB.ChannelID = &channelId
|
fileDB.ChannelId = &channelId
|
||||||
fileDB.MimeType = fileIn.MimeType.Value
|
fileDB.MimeType = fileIn.MimeType.Value
|
||||||
fileDB.Category = string(category.GetCategory(fileIn.Name))
|
fileDB.Category = string(category.GetCategory(fileIn.Name))
|
||||||
if len(fileIn.Parts) > 0 {
|
if len(fileIn.Parts) > 0 {
|
||||||
parts := []api.Part{}
|
fileDB.Parts = datatypes.NewJSONSlice(mapParts(fileIn.Parts))
|
||||||
for _, part := range fileIn.Parts {
|
|
||||||
p := api.Part{ID: part.ID}
|
|
||||||
if part.Salt.Value != "" {
|
|
||||||
p.Salt = part.Salt
|
|
||||||
}
|
|
||||||
parts = append(parts, p)
|
|
||||||
}
|
|
||||||
fileDB.Parts = datatypes.NewJSONSlice(parts)
|
|
||||||
}
|
}
|
||||||
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
|
fileDB.Size = utils.Ptr(fileIn.Size.Or(0))
|
||||||
}
|
}
|
||||||
fileDB.Name = fileIn.Name
|
fileDB.Name = fileIn.Name
|
||||||
fileDB.Type = string(fileIn.Type)
|
fileDB.Type = string(fileIn.Type)
|
||||||
fileDB.UserID = userId
|
fileDB.UserId = userId
|
||||||
fileDB.Status = "active"
|
fileDB.Status = "active"
|
||||||
fileDB.Encrypted = fileIn.Encrypted.Value
|
fileDB.Encrypted = fileIn.Encrypted.Value
|
||||||
if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() {
|
if fileIn.UpdatedAt.IsSet() && !fileIn.UpdatedAt.Value.IsZero() {
|
||||||
|
@ -323,11 +303,11 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
|
||||||
}
|
}
|
||||||
return nil, &apiError{err: err}
|
return nil, &apiError{err: err}
|
||||||
}
|
}
|
||||||
return mapper.ToFileOut(fileDB, false), nil
|
return mapper.ToFileOut(fileDB), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCreate, params api.FilesCreateShareParams) error {
|
func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCreate, params api.FilesCreateShareParams) error {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
|
|
||||||
var fileShare models.FileShare
|
var fileShare models.FileShare
|
||||||
|
|
||||||
|
@ -339,11 +319,11 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
|
||||||
fileShare.Password = utils.Ptr(string(bytes))
|
fileShare.Password = utils.Ptr(string(bytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
fileShare.FileID = params.ID
|
fileShare.FileId = params.ID
|
||||||
if req.ExpiresAt.IsSet() {
|
if req.ExpiresAt.IsSet() {
|
||||||
fileShare.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
fileShare.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
||||||
}
|
}
|
||||||
fileShare.UserID = userId
|
fileShare.UserId = userId
|
||||||
|
|
||||||
if err := a.db.Create(&fileShare).Error; err != nil {
|
if err := a.db.Create(&fileShare).Error; err != nil {
|
||||||
return &apiError{err: err}
|
return &apiError{err: err}
|
||||||
|
@ -353,7 +333,7 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
|
func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
|
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
|
||||||
return &apiError{err: err}
|
return &apiError{err: err}
|
||||||
}
|
}
|
||||||
|
@ -362,7 +342,7 @@ func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDeleteShareParams) error {
|
func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDeleteShareParams) error {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
|
|
||||||
var deletedShare models.FileShare
|
var deletedShare models.FileShare
|
||||||
|
|
||||||
|
@ -371,14 +351,14 @@ func (a *apiService) FilesDeleteShare(ctx context.Context, params api.FilesDelet
|
||||||
return &apiError{err: err}
|
return &apiError{err: err}
|
||||||
}
|
}
|
||||||
if deletedShare.ID != "" {
|
if deletedShare.ID != "" {
|
||||||
a.cache.Delete(fmt.Sprintf("shares:%s", deletedShare.ID))
|
a.cache.Delete(cache.Key("shared", deletedShare.ID))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreate, params api.FilesEditShareParams) error {
|
func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreate, params api.FilesEditShareParams) error {
|
||||||
userId, _ := auth.GetUser(ctx)
|
userId := auth.GetUser(ctx)
|
||||||
|
|
||||||
var fileShareUpdate models.FileShare
|
var fileShareUpdate models.FileShare
|
||||||
|
|
||||||
|
@ -387,7 +367,7 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &apiError{err: err}
|
return &apiError{err: err}
|
||||||
}
|
}
|
||||||
fileShareUpdate.Password = utils.StringPointer(string(bytes))
|
fileShareUpdate.Password = utils.Ptr(string(bytes))
|
||||||
}
|
}
|
||||||
if req.ExpiresAt.IsSet() {
|
if req.ExpiresAt.IsSet() {
|
||||||
fileShareUpdate.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
fileShareUpdate.ExpiresAt = utils.Ptr(req.ExpiresAt.Value)
|
||||||
@@ -403,26 +383,25 @@ func (a *apiService) FilesEditShare(ctx context.Context, req *api.FileShareCreat

 func (a *apiService) FilesGetById(ctx context.Context, params api.FilesGetByIdParams) (*api.File, error) {
     var result []fullFileDB
-    notFoundResponse := &apiError{err: errors.New("file not found"), code: 404}
     if err := a.db.Model(&models.File{}).Select("*",
         "(select get_path_from_file_id as path from teldrive.get_path_from_file_id(id))").
         Where("id = ?", params.ID).Scan(&result).Error; err != nil {
-        if database.IsRecordNotFoundErr(err) {
-            return nil, notFoundResponse
-        }
         return nil, &apiError{err: err}
     }
     if len(result) == 0 {
-        return nil, notFoundResponse
+        return nil, &apiError{err: errors.New("file not found"), code: 404}
     }
-    res := mapper.ToFileOut(result[0].File, true)
+    res := mapper.ToFileOut(result[0].File)
     res.Path = api.NewOptString(result[0].Path)
+    if result[0].ChannelId != nil {
+        res.ChannelId = api.NewOptInt64(*result[0].ChannelId)
+    }

     return res, nil
 }

 func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams) (*api.FileList, error) {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     queryBuilder := &fileQueryBuilder{db: a.db}

@@ -430,7 +409,7 @@ func (a *apiService) FilesList(ctx context.Context, params api.FilesListParams)
 }

 func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     if err := a.db.Exec("select * from teldrive.create_directories(?, ?)", userId, req.Path).Error; err != nil {
         return &apiError{err: err}
@@ -439,18 +418,18 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {
 }

 func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
     items := pgtype.Array[string]{
         Elements: req.Ids,
         Valid: true,
         Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}},
     }
-    if !isUUID(req.Destination) {
+    if !isUUId(req.Destination) {
         r, err := a.getFileFromPath(req.Destination, userId)
         if err != nil {
             return &apiError{err: err}
         }
-        req.Destination = r.Id
+        req.Destination = r.ID
     }
     if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId).
         Update("parent_id", req.Destination).Error; err != nil {
@@ -462,7 +441,7 @@ func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
 }

 func (a *apiService) FilesShareByid(ctx context.Context, params api.FilesShareByidParams) (*api.FileShare, error) {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
     var result []models.FileShare

     notFoundErr := &apiError{err: errors.New("invalid share"), code: 404}
@@ -500,15 +479,7 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
         updateDb.Name = req.Name.Value
     }
     if len(req.Parts) > 0 {
-        parts := []api.Part{}
-        for _, part := range req.Parts {
-            p := api.Part{ID: part.ID}
-            if part.Salt.Value != "" {
-                p.Salt = part.Salt
-            }
-            parts = append(parts, p)
-        }
-        updateDb.Parts = datatypes.NewJSONSlice(parts)
+        updateDb.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
     }
     if req.Size.Value != 0 {
         updateDb.Size = utils.Ptr(req.Size.Value)
@@ -523,17 +494,18 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
         return nil, &apiError{err: err}
     }

-    a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
+    a.cache.Delete(cache.Key("files", params.ID))

     file := models.File{}
     if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil {
         return nil, &apiError{err: err}
     }
-    return mapper.ToFileOut(file, false), nil
+    return mapper.ToFileOut(file), nil
 }

 func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     var file models.File

@@ -545,29 +517,18 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
         if err != nil {
             return &apiError{err: err}
         }
-        updatePayload.ChannelID = &channelId
+        updatePayload.ChannelId = &channelId
     } else {
-        updatePayload.ChannelID = &req.ChannelId.Value
+        updatePayload.ChannelId = &req.ChannelId.Value
     }
     if len(req.Parts) > 0 {
-        parts := []api.Part{}
-        for _, part := range req.Parts {
-            p := api.Part{ID: part.ID}
-            if part.Salt.Value != "" {
-                p.Salt = part.Salt
-            }
-            parts = append(parts, p)
-        }
-        updatePayload.Parts = datatypes.NewJSONSlice(parts)
+        updatePayload.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
     }
     if req.Name.Value != "" {
         updatePayload.Name = req.Name.Value
     }
     if req.ParentId.Value != "" {
-        updatePayload.ParentID = sql.NullString{
-            String: req.ParentId.Value,
-            Valid: true,
-        }
+        updatePayload.ParentId = utils.Ptr(req.ParentId.Value)
     }

     updatePayload.UpdatedAt = req.UpdatedAt
@@ -592,28 +553,23 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
         return &apiError{err: err}
     }

-    if len(file.Parts) > 0 && file.ChannelID != nil {
-        _, session := auth.GetUser(ctx)
-        ids := []int{}
-        for _, part := range file.Parts {
-            ids = append(ids, int(part.ID))
-        }
-        client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
-        tgc.DeleteMessages(ctx, client, *file.ChannelID, ids)
-        keys := []string{fmt.Sprintf("files:%s", params.ID), fmt.Sprintf("files:messages:%s", params.ID)}
-        for _, part := range file.Parts {
-            keys = append(keys, fmt.Sprintf("files:location:%s:%d", params.ID, part.ID))
-        }
-        a.cache.Delete(keys...)
-    }
-    a.cache.Delete(fmt.Sprintf("files:%s", params.ID))
+    keys := []string{cache.Key("files", params.ID)}
+    if len(file.Parts) > 0 && file.ChannelId != nil {
+        ids := utils.Map(file.Parts, func(part api.Part) int { return part.ID })
+        client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
+        tgc.DeleteMessages(ctx, client, *file.ChannelId, ids)
+        keys = append(keys, cache.Key("files", "messages", params.ID))
+        for _, part := range file.Parts {
+            keys = append(keys, cache.Key("files", "location", params.ID, part.ID))
+        }
+    }
+    a.cache.Delete(keys...)

     return nil
 }

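The hand-written loops that collected Telegram message ids (and, in the query builder further down, mapped rows to API files) are replaced with utils.Map. The generic helper is not included in this diff; a sketch matching the call utils.Map(file.Parts, func(part api.Part) int { return part.ID }):

package utils

// Map is a hypothetical sketch of utils.Map, not the actual definition.
func Map[T, R any](items []T, fn func(T) R) []R {
    result := make([]R, 0, len(items))
    for _, item := range items {
        result = append(result, fn(item))
    }
    return result
}
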
-func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileID string, userId int64) {
+func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fileId string, userId int64) {
     ctx := r.Context()
     var (
         session *models.Session
@@ -646,19 +602,17 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
         session = &models.Session{UserId: userId}
     }

-    file := &api.File{}
-    key := fmt.Sprintf("files:%s", fileID)
-    err = e.api.cache.Get(key, file)
-    if err != nil {
-        file, err = e.api.FilesGetById(ctx, api.FilesGetByIdParams{ID: fileID})
-        if err != nil {
-            http.Error(w, err.Error(), http.StatusNotFound)
-            return
-        }
-        e.api.cache.Set(key, file, 0)
-    }
+    file, err := cache.Fetch(e.api.cache, cache.Key("files", fileId), 0, func() (*models.File, error) {
+        var result models.File
+        if err := e.api.db.Model(&result).Where("id = ?", fileId).First(&result).Error; err != nil {
+            return nil, err
+        }
+        return &result, nil
+    })
+
+    if err != nil {
+        http.Error(w, err.Error(), http.StatusBadRequest)
+        return
+    }

     w.Header().Set("Accept-Ranges", "bytes")
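FilesStream no longer runs its own Get / DB query / Set sequence; it calls cache.Fetch, a read-through helper that returns the cached value or fills the cache from the supplied loader. The helper is defined outside this diff; a sketch under the assumption that the cache exposes the Get(key, dest) and Set(key, value, ttl) methods visible in the removed code:

package cache

import "time"

// Cacher mirrors the Get/Set calls seen in the old FilesStream code; the real interface may differ.
type Cacher interface {
    Get(key string, value any) error
    Set(key string, value any, expiry time.Duration) error
}

// Fetch is a hypothetical read-through sketch matching the new call site above.
func Fetch[T any](c Cacher, key string, expiry time.Duration, loader func() (T, error)) (T, error) {
    var value T
    if err := c.Get(key, &value); err == nil {
        return value, nil // cache hit
    }
    value, err := loader() // cache miss: load from the source of truth
    if err != nil {
        return value, err
    }
    c.Set(key, value, expiry)
    return value, nil
}
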
@@ -666,16 +620,15 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
     var start, end int64

     rangeHeader := r.Header.Get("Range")
+    contentType := defaultContentType

-    if file.Size.Value == 0 {
-        w.Header().Set("Content-Type", file.MimeType.Or(defaultContentType))
-        w.Header().Set("Content-Length", "0")
-        if rangeHeader != "" {
-            w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
-            http.Error(w, "Requested Range Not Satisfiable", http.StatusRequestedRangeNotSatisfiable)
-            return
-        }
+    if file.MimeType != "" {
+        contentType = file.MimeType
+    }
+
+    if file.Size == nil || *file.Size == 0 {
+        w.Header().Set("Content-Type", contentType)
+        w.Header().Set("Content-Length", "0")
         w.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": file.Name}))
         w.WriteHeader(http.StatusOK)
         return
@@ -684,11 +637,11 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
     status := http.StatusOK
     if rangeHeader == "" {
         start = 0
-        end = file.Size.Value - 1
+        end = *file.Size - 1
     } else {
-        ranges, err := http_range.Parse(rangeHeader, file.Size.Value)
+        ranges, err := http_range.Parse(rangeHeader, *file.Size)
         if err == http_range.ErrNoOverlap {
-            w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", file.Size.Value))
+            w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", *file.Size))
             http.Error(w, http_range.ErrNoOverlap.Error(), http.StatusRequestedRangeNotSatisfiable)
             return
         }
@@ -702,20 +655,18 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
         }
         start = ranges[0].Start
         end = ranges[0].End
-        w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, file.Size.Value))
+        w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, *file.Size))
         status = http.StatusPartialContent

     }

     contentLength := end - start + 1

-    mimeType := file.MimeType.Or(defaultContentType)
-    w.Header().Set("Content-Type", mimeType)
+    w.Header().Set("Content-Type", contentType)

     w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
-    w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileID+strconv.FormatInt(file.Size.Value, 10))))
-    w.Header().Set("Last-Modified", file.UpdatedAt.Value.UTC().Format(http.TimeFormat))
+    w.Header().Set("E-Tag", fmt.Sprintf("\"%s\"", md5.FromString(fileId+strconv.FormatInt(*file.Size, 10))))
+    w.Header().Set("Last-Modified", file.UpdatedAt.UTC().Format(http.TimeFormat))

     disposition := "inline"

@@ -733,7 +684,7 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
         return
     }

-    tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, file.ChannelId.Value)
+    tokens, err := getBotsToken(e.api.db, e.api.cache, session.UserId, *file.ChannelId)

     if err != nil {
         http.Error(w, "failed to get bots", http.StatusInternalServerError)
@@ -761,9 +712,9 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
         multiThreads = 0

     } else {
-        e.api.worker.Set(tokens, file.ChannelId.Value)
-        token, _ = e.api.worker.Next(file.ChannelId.Value)
+        e.api.worker.Set(tokens, *file.ChannelId)
+        token, _ = e.api.worker.Next(*file.ChannelId)

         client, err = tgc.BotClient(ctx, e.api.boltdb, &e.api.cnf.TG, token, middlewares...)
         if err != nil {
@@ -816,5 +767,16 @@ func (e *extendedService) SharesStream(w http.ResponseWriter, r *http.Request, s
         http.Error(w, err.Error(), http.StatusUnauthorized)
         return
     }
-    e.FilesStream(w, r, fileId, share.UserID)
+    e.FilesStream(w, r, fileId, share.UserId)
 }
+
+func mapParts(_parts []api.Part) []api.Part {
+    return utils.Map(_parts, func(part api.Part) api.Part {
+        p := api.Part{ID: part.ID}
+        if part.Salt.Value != "" {
+            p.Salt = part.Salt
+        }
+        return p
+    })
+}

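mapParts consolidates the identical Salt-filtering loops that FilesUpdate and FilesUpdateParts used to carry inline; both call sites now reduce to a single line (updateDb in one handler, updatePayload in the other):

// Both handlers above now build the JSON parts column the same way.
updateDb.Parts = datatypes.NewJSONSlice(mapParts(req.Parts))
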
@@ -18,7 +18,6 @@ import (

 type fileQueryBuilder struct {
     db *gorm.DB
-    selectAllFields bool
 }

 type fileResponse struct {
@@ -26,7 +25,7 @@ type fileResponse struct {
     Total int
 }

-var selectedFields = []string{"id", "name", "type", "mime_type", "category", "encrypted", "size", "parent_id", "updated_at"}
+var selectedFields = []string{"id", "name", "type", "mime_type", "category", "channel_id", "encrypted", "size", "parent_id", "updated_at"}

 func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int64) (*api.FileList, error) {
     query := afb.db.Where("user_id = ?", userId).Where("status = ?", filesQuery.Status.Value)
@@ -50,11 +49,7 @@ func (afb *fileQueryBuilder) execute(filesQuery *api.FilesListParams, userId int
         count = res[0].Total
     }

-    files := []api.File{}
-    for _, file := range res {
-        files = append(files, *mapper.ToFileOut(file.File, afb.selectAllFields))
-    }
+    files := utils.Map(res, func(item fileResponse) api.File { return *mapper.ToFileOut(item.File) })

     return &api.FileList{Items: files,
         Meta: api.FileListMeta{Count: count,
@@ -192,15 +187,10 @@ func (afb *fileQueryBuilder) buildFileQuery(query *gorm.DB, filesQuery *api.File
     orderField := utils.CamelToSnake(string(filesQuery.Sort.Value))
     op := getOrderOperation(filesQuery)

-    fields := selectedFields
-    if afb.selectAllFields {
-        fields = append(fields, "parts", "channel_id")
-    }
-
     return afb.buildSubqueryCTE(query, filesQuery, userId).Clauses(exclause.NewWith("ranked_scores", afb.db.Model(&models.File{}).Select(orderField, "count(*) OVER () as total",
         fmt.Sprintf("ROW_NUMBER() OVER (ORDER BY %s %s) AS rank", orderField, strings.ToUpper(string(filesQuery.Order.Value)))).
         Where(query))).Model(&models.File{}).
-        Select(fields, "(select total from ranked_scores limit 1) as total").
+        Select(selectedFields, "(select total from ranked_scores limit 1) as total").
         Where(fmt.Sprintf("%s %s (SELECT %s FROM ranked_scores WHERE rank = ?)", orderField, op, orderField),
         max((filesQuery.Page.Value-1)*filesQuery.Limit.Value, 1)).
         Where(query).Order(getOrder(filesQuery)).Limit(filesQuery.Limit.Value)
@@ -223,13 +213,6 @@ func getOrder(filesQuery *api.FilesListParams) string {
     return fmt.Sprintf("%s %s", orderField, strings.ToUpper(string(filesQuery.Order.Value)))
 }

-func max(x int, y int) int {
-    if x > y {
-        return x
-    }
-    return y
-}

 func getOrderOperation(filesQuery *api.FilesListParams) string {
     if filesQuery.Page.Value == 1 {
         if filesQuery.Order.Value == api.FileQueryOrderAsc {

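The local max helper is deleted above while buildFileQuery keeps calling max(...), so the code now presumably relies on the built-in max that Go ships since 1.21:

// With the helper removed, this argument in buildFileQuery resolves to Go's built-in max (Go 1.21+):
max((filesQuery.Page.Value-1)*filesQuery.Limit.Value, 1)
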
@@ -60,7 +60,7 @@ func (a *apiService) SharesGetById(ctx context.Context, params api.SharesGetById
     }
     res := &api.FileShareInfo{
         Protected: share.Password != nil,
-        UserId: share.UserID,
+        UserId: share.UserId,
         Type: share.Type,
         Name: share.Name,
     }
@@ -104,16 +104,16 @@ func (a *apiService) SharesListFiles(ctx context.Context, params api.SharesListF
         Status: api.NewOptFileQueryStatus(api.FileQueryStatusActive),
         Order: api.NewOptFileQueryOrder(api.FileQueryOrder(string(params.Order.Value))),
         Sort: api.NewOptFileQuerySort(api.FileQuerySort(string(params.Sort.Value))),
-        Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserID)
+        Operation: api.NewOptFileQueryOperation(api.FileQueryOperationList)}, share.UserId)
     } else {
         var file models.File
-        if err := a.db.Where("id = ?", share.FileID).First(&file).Error; err != nil {
+        if err := a.db.Where("id = ?", share.FileId).First(&file).Error; err != nil {
             if database.IsRecordNotFoundErr(err) {
                 return nil, &apiError{err: database.ErrNotFound, code: http.StatusNotFound}
             }
             return nil, &apiError{err: err}
         }
-        return &api.FileList{Items: []api.File{*mapper.ToFileOut(file, false)},
+        return &api.FileList{Items: []api.File{*mapper.ToFileOut(file)},
             Meta: api.FileListMeta{Count: 1, TotalPages: 1, CurrentPage: 1}}, nil
     }

@@ -48,7 +48,7 @@ func (a *apiService) UploadsPartsById(ctx context.Context, params api.UploadsPar
 }

 func (a *apiService) UploadsStats(ctx context.Context, params api.UploadsStatsParams) ([]api.UploadStats, error) {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
     var stats []api.UploadStats
     err := a.db.Raw(`
     SELECT
@@ -94,7 +94,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
         return nil, &apiError{err: errors.New("encryption is not enabled"), code: 400}
     }

-    userId, session := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     fileStream := req.Content.Data

@@ -116,7 +116,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
     }

     if len(tokens) == 0 {
-        client, err = tgc.AuthClient(ctx, &a.cnf.TG, session)
+        client, err = tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession)
         if err != nil {
             return nil, err
         }
@@ -220,7 +220,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
         Name: params.PartName,
         UploadId: params.ID,
         PartId: message.ID,
-        ChannelID: channelId,
+        ChannelId: channelId,
         Size: fileSize,
         PartNo: int(params.PartNo),
         UserId: userId,
@@ -244,7 +244,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
     out = api.UploadPart{
         Name: partUpload.Name,
         PartId: partUpload.PartId,
-        ChannelId: partUpload.ChannelID,
+        ChannelId: partUpload.ChannelId,
         PartNo: partUpload.PartNo,
         Size: partUpload.Size,
         Encrypted: partUpload.Encrypted,

@@ -16,6 +16,7 @@ import (
     "github.com/gotd/td/tgerr"
     "github.com/tgdrive/teldrive/internal/api"
     "github.com/tgdrive/teldrive/internal/auth"
+    "github.com/tgdrive/teldrive/internal/cache"
     "github.com/tgdrive/teldrive/internal/tgc"
     "github.com/tgdrive/teldrive/pkg/models"
     "github.com/tgdrive/teldrive/pkg/types"
@@ -27,8 +28,8 @@ import (
 )

 func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {
-    userId, session := auth.GetUser(ctx)
-    client, _ := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
+    userId := auth.GetUser(ctx)
+    client, _ := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)

     if len(req.Bots) > 0 {
         channelId, err := getDefaultChannel(a.db, a.cache, userId)
@@ -47,11 +48,11 @@ func (a *apiService) UsersAddBots(ctx context.Context, req *api.AddBots) error {

 func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, error) {

-    userID, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     channels := make(map[int64]*api.Channel)

-    peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userID)))
+    peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))

     iter, err := peerStorage.Iterate(ctx)
     if err != nil {
@@ -67,6 +68,7 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
         }

     }

     res := []api.Channel{}
     for _, channel := range channels {
         res = append(res, *channel)
@@ -79,10 +81,10 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
 }

 func (a *apiService) UsersSyncChannels(ctx context.Context) error {
-    userId, session := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
     peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
     collector := storage.CollectPeers(peerStorage)
-    client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
+    client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
     if err != nil {
         return &apiError{err: err}
     }
@@ -96,7 +98,9 @@ func (a *apiService) UsersSyncChannels(ctx context.Context) error {
 }

 func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, error) {
-    userId, userSession := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
+
+    userSession := auth.GetJWTUser(ctx).TgSession

     client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...)

@@ -150,9 +154,8 @@ func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession,
 }

 func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProfileImageParams) (*api.UsersProfileImageOKHeaders, error) {
-    _, session := auth.GetUser(ctx)
-
-    client, err := tgc.AuthClient(ctx, &a.cnf.TG, session, a.middlewares...)
+    client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)

     if err != nil {
         return nil, &apiError{err: err}
@@ -194,25 +197,25 @@ func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProf
 }

 func (a *apiService) UsersRemoveBots(ctx context.Context) error {
-    userID, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

-    channelId, err := getDefaultChannel(a.db, a.cache, userID)
+    channelId, err := getDefaultChannel(a.db, a.cache, userId)
     if err != nil {
         return &apiError{err: err}
     }

-    if err := a.db.Where("user_id = ?", userID).Where("channel_id = ?", channelId).
+    if err := a.db.Where("user_id = ?", userId).Where("channel_id = ?", channelId).
         Delete(&models.Bot{}).Error; err != nil {
         return &apiError{err: err}
     }

-    a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId))
+    a.cache.Delete(cache.Key("users", "bots", userId, channelId))

     return nil
 }

 func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRemoveSessionParams) error {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

     session := &models.Session{}

@@ -236,15 +239,18 @@ func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRem
 }

 func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
-    userID, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)
     var (
         channelId int64
         err error
     )

-    channelId, _ = getDefaultChannel(a.db, a.cache, userID)
+    channelId, err = getDefaultChannel(a.db, a.cache, userId)
+    if err != nil {
+        return nil, &apiError{err: err}
+    }

-    tokens, err := getBotsToken(a.db, a.cache, userID, channelId)
+    tokens, err := getBotsToken(a.db, a.cache, userId, channelId)

     if err != nil {
         return nil, &apiError{err: err}
@@ -253,12 +259,12 @@ func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
 }

 func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpdate) error {
-    userId, _ := auth.GetUser(ctx)
+    userId := auth.GetUser(ctx)

-    channel := &models.Channel{UserID: userId, Selected: true}
+    channel := &models.Channel{UserId: userId, Selected: true}

     if req.ChannelId.Value != 0 {
-        channel.ChannelID = req.ChannelId.Value
+        channel.ChannelId = req.ChannelId.Value
     }
     if req.ChannelName.Value != "" {
         channel.ChannelName = req.ChannelName.Value
@@ -270,11 +276,10 @@ func (a *apiService) UsersUpdateChannel(ctx context.Context, req *api.ChannelUpd
     }).Create(channel).Error; err != nil {
         return &apiError{err: errors.New("failed to update channel")}
     }
-    a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelID).
+    a.db.Model(&models.Channel{}).Where("channel_id != ?", channel.ChannelId).
         Where("user_id = ?", userId).Update("selected", false)

-    key := fmt.Sprintf("users:channel:%d", userId)
-    a.cache.Set(key, channel.ChannelID, 0)
+    a.cache.Set(cache.Key("users", "channel", userId), channel.ChannelId, 0)
     return nil
 }

@@ -359,12 +364,12 @@ func (a *apiService) addBots(c context.Context, client *telegram.Client, userId
     payload := []models.Bot{}

     for _, info := range botInfoMap {
-        payload = append(payload, models.Bot{UserID: userId, Token: info.Token, BotID: info.Id,
-            BotUserName: info.UserName, ChannelID: channelId,
+        payload = append(payload, models.Bot{UserId: userId, Token: info.Token, BotId: info.Id,
+            BotUserName: info.UserName, ChannelId: channelId,
         })
     }

-    a.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId))
+    a.cache.Delete(cache.Key("users", "bots", userId, channelId))

     if err := a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil {
         return err