refactor: migrate session db from bbolt to sqlite

This commit is contained in:
Bhunter 2025-01-15 06:21:08 +01:00
parent 9573db7ea7
commit 0d51fdfe0a
20 changed files with 374 additions and 112 deletions

View file

@ -28,6 +28,9 @@ builds:
- amd64
- arm
- arm64
ignore:
- goos: windows
goarch: arm
checksum:
name_template: "{{ .ProjectName }}_checksums.txt"

View file

@ -25,8 +25,8 @@ import (
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/middleware"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/internal/tgstorage"
"github.com/tgdrive/teldrive/ui"
"go.etcd.io/bbolt"
"github.com/tgdrive/teldrive/pkg/cron"
"github.com/tgdrive/teldrive/pkg/services"
@ -177,15 +177,19 @@ func runApplication(ctx context.Context, conf *config.ServerCmdConfig) {
lg.Fatalw("failed to migrate database", "err", err)
}
boltDb, err := tgc.NewBoltDB(conf.TG.SessionFile)
tgdb, err := tgstorage.NewDatabase(conf.TG.StorageFile)
if err != nil {
lg.Fatalw("failed to create bolt db", "err", err)
lg.Fatalw("failed to create tg db", "err", err)
}
err = tgstorage.MigrateDB(tgdb)
if err != nil {
lg.Fatalw("failed to migrate tg db", "err", err)
}
worker := tgc.NewBotWorker()
srv := setupServer(conf, db, cacher, boltDb, worker)
srv := setupServer(conf, db, cacher, tgdb, worker)
cron.StartCronJobs(scheduler, db, conf)
@ -213,11 +217,11 @@ func runApplication(ctx context.Context, conf *config.ServerCmdConfig) {
lg.Info("Server stopped")
}
func setupServer(cfg *config.ServerCmdConfig, db *gorm.DB, cache cache.Cacher, boltdb *bbolt.DB, worker *tgc.BotWorker) *http.Server {
func setupServer(cfg *config.ServerCmdConfig, db *gorm.DB, cache cache.Cacher, tgdb *gorm.DB, worker *tgc.BotWorker) *http.Server {
lg := logging.DefaultLogger()
apiSrv := services.NewApiService(db, cfg, cache, boltdb, worker)
apiSrv := services.NewApiService(db, cfg, cache, tgdb, worker)
srv, err := api.NewServer(apiSrv, auth.NewSecurityHandler(db, cache, &cfg.JWT))

View file

@ -7,7 +7,7 @@ services:
- postgres
volumes:
- ./config.toml:/config.toml
- ./session.db:/session.db
- ./storage.db:/storage.db
ports:
- 8080:8080
networks:

10
go.mod
View file

@ -6,6 +6,7 @@ require (
github.com/Masterminds/semver/v3 v3.3.1
github.com/WinterYukky/gorm-extra-clause-plugin v0.3.0
github.com/coocood/freecache v1.2.4
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.0
github.com/go-chi/cors v1.2.1
github.com/go-co-op/gocron v1.37.0
@ -21,7 +22,6 @@ require (
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
github.com/vmihailenco/msgpack/v5 v5.4.1
go.etcd.io/bbolt v1.3.11
go.uber.org/zap v1.27.0
golang.org/x/time v0.9.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
@ -38,9 +38,11 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/fsnotify/fsnotify v1.8.0 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/glebarez/go-sqlite v1.22.0 // indirect
github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
@ -51,7 +53,9 @@ require (
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
@ -68,6 +72,10 @@ require (
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gorm.io/driver/mysql v1.5.7 // indirect
modernc.org/libc v1.55.3 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/sqlite v1.34.1 // indirect
)

22
go.sum
View file

@ -59,6 +59,10 @@ github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK
github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-chi/chi/v5 v5.2.0 h1:Aj1EtB0qR2Rdo2dG4O94RIU35w2lvQSj6BRA4+qwFL0=
github.com/go-chi/chi/v5 v5.2.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
@ -95,6 +99,8 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@ -225,8 +231,6 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0=
go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw=
@ -258,8 +262,6 @@ golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
@ -295,6 +297,14 @@ gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ=
modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ=
modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y=
modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s=
modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE=
modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ=
modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw=
modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI=
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4=
modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U=
@ -303,6 +313,10 @@ modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E=
modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc=
modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss=
modernc.org/sqlite v1.34.1 h1:u3Yi6M0N8t9yKRDwhXcyp1eS5/ErhPTBggxWFuR6Hfk=
modernc.org/sqlite v1.34.1/go.mod h1:pXV2xHxhzXZsgT/RtTFAPY6JJDEvOTcTdwADQCCWD4k=
modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA=

View file

@ -1,4 +1,4 @@
FROM scratch
COPY teldrive /teldrive
EXPOSE 8080
ENTRYPOINT ["/teldrive","run","--tg-session-file","/session.db"]
ENTRYPOINT ["/teldrive","run","--tg-storage-file","/storage.db"]

View file

@ -74,7 +74,7 @@ type TGConfig struct {
SystemLangCode string `mapstructure:"system-lang-code"`
LangPack string `mapstructure:"lang-pack"`
Ntp bool `mapstructure:"ntp"`
SessionFile string `mapstructure:"session-file"`
StorageFile string `mapstructure:"storage-file"`
DisableStreamBots bool `mapstructure:"disable-stream-bots"`
Proxy string `mapstructure:"proxy"`
ReconnectTimeout time.Duration `mapstructure:"reconnect-timeout"`
@ -211,7 +211,7 @@ func AddCommonFlags(flags *pflag.FlagSet, config *ServerCmdConfig) {
// Telegram config
flags.IntVar(&config.TG.AppId, "tg-app-id", 0, "Telegram app ID")
flags.StringVar(&config.TG.AppHash, "tg-app-hash", "", "Telegram app hash")
flags.StringVar(&config.TG.SessionFile, "tg-session-file", "", "Bot session file path")
flags.StringVar(&config.TG.StorageFile, "tg-storage-file", "", "Sqlite Storage file path")
flags.BoolVar(&config.TG.RateLimit, "tg-rate-limit", true, "Enable rate limiting for telegram client")
flags.IntVar(&config.TG.RateBurst, "tg-rate-burst", 5, "Limiting burst for telegram client")
flags.IntVar(&config.TG.Rate, "tg-rate", 100, "Limiting rate for telegram client")

View file

@ -1,35 +0,0 @@
package tgc
import (
"os"
"path/filepath"
"time"
"github.com/tgdrive/teldrive/internal/utils"
"go.etcd.io/bbolt"
)
func NewBoltDB(sessionFile string) (*bbolt.DB, error) {
if sessionFile == "" {
dir, err := os.UserHomeDir()
if err != nil {
dir = utils.ExecutableDir()
} else {
dir = filepath.Join(dir, ".teldrive")
err := os.Mkdir(dir, 0755)
if err != nil && !os.IsExist(err) {
dir = utils.ExecutableDir()
}
}
sessionFile = filepath.Join(dir, "session.db")
}
db, err := bbolt.Open(sessionFile, 0666, &bbolt.Options{
Timeout: time.Second,
NoGrowSync: false,
})
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -14,8 +14,8 @@ import (
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/utils"
"github.com/tgdrive/teldrive/pkg/types"
"go.etcd.io/bbolt"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
)
var (
@ -28,7 +28,6 @@ func GetChannelById(ctx context.Context, client *tg.Client, channelId int64) (*t
ChannelID: channelId,
}
channels, err := client.ChannelsGetChannels(ctx, []tg.InputChannelClass{inputChannel})
if err != nil {
return nil, err
}
@ -181,10 +180,10 @@ func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFi
return buff, nil
}
func GetBotInfo(ctx context.Context, boltdb *bbolt.DB, config *config.TGConfig, token string) (*types.BotInfo, error) {
func GetBotInfo(ctx context.Context, db *gorm.DB, config *config.TGConfig, token string) (*types.BotInfo, error) {
var user *tg.User
middlewares := NewMiddleware(config, WithFloodWait(), WithRateLimit())
client, _ := BotClient(ctx, boltdb, config, token, middlewares...)
client, _ := BotClient(ctx, db, config, token, middlewares...)
err := RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)
return nil

View file

@ -7,28 +7,25 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/go-faster/errors"
tgbbolt "github.com/gotd/contrib/bbolt"
"github.com/gotd/contrib/clock"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/gotd/contrib/middleware/ratelimit"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/dcs"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/logging"
"github.com/tgdrive/teldrive/internal/recovery"
"github.com/tgdrive/teldrive/internal/retry"
"github.com/tgdrive/teldrive/internal/tgstorage"
"github.com/tgdrive/teldrive/internal/utils"
"go.etcd.io/bbolt"
"go.uber.org/zap"
"golang.org/x/net/proxy"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
func sessionKey(indexes ...string) string {
return strings.Join(indexes, ":")
}
func newClient(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHandler, storage session.Storage, middlewares ...telegram.Middleware) (*telegram.Client, error) {
var dialer dcs.DialFunc = proxy.Direct.DialContext
@ -107,9 +104,9 @@ func AuthClient(ctx context.Context, config *config.TGConfig, sessionStr string,
return newClient(ctx, config, nil, storage, middlewares...)
}
func BotClient(ctx context.Context, boltdb *bbolt.DB, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
func BotClient(ctx context.Context, db *gorm.DB, config *config.TGConfig, token string, middlewares ...telegram.Middleware) (*telegram.Client, error) {
storage := tgbbolt.NewSessionStorage(boltdb, sessionKey("botsession", token), []byte("teldrive"))
storage := tgstorage.NewSessionStorage(db, cache.Key("sessions", strings.Split(token, ":")[0]))
return newClient(ctx, config, nil, storage, middlewares...)

46
internal/tgstorage/db.go Normal file
View file

@ -0,0 +1,46 @@
package tgstorage
import (
"os"
"path/filepath"
"time"
"github.com/glebarez/sqlite"
"github.com/go-faster/errors"
"github.com/tgdrive/teldrive/internal/utils"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func NewDatabase(storageFile string) (*gorm.DB, error) {
if storageFile == "" {
dir, err := os.UserHomeDir()
if err != nil {
dir = utils.ExecutableDir()
} else {
dir = filepath.Join(dir, ".teldrive")
err := os.Mkdir(dir, 0755)
if err != nil && !os.IsExist(err) {
dir = utils.ExecutableDir()
}
}
storageFile = filepath.Join(dir, "storage.db")
}
db, err := gorm.Open(sqlite.Open(storageFile), &gorm.Config{NowFunc: func() time.Time {
return time.Now().UTC()
}, Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
return nil, err
}
return db, nil
}
func MigrateDB(db *gorm.DB) error {
if err := db.AutoMigrate(&KeyValue{}); err != nil {
return errors.Wrap(err, "auto migrate")
}
return nil
}

46
internal/tgstorage/kv.go Normal file
View file

@ -0,0 +1,46 @@
package tgstorage
import (
"context"
"github.com/go-faster/errors"
"gorm.io/gorm"
"github.com/gotd/contrib/auth/kv"
)
type KeyValue struct {
Key string `gorm:"primaryKey;column:key"`
Value []byte `gorm:"not null;column:value;type:blob"`
}
func (KeyValue) TableName() string {
return "kv"
}
type kvStorage struct {
db *gorm.DB
}
func (s kvStorage) Set(ctx context.Context, k, v string) error {
return s.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&KeyValue{
Key: k,
Value: []byte(v),
}).Error; err != nil {
return errors.Wrap(err, "save value")
}
return nil
})
}
func (s kvStorage) Get(ctx context.Context, k string) (string, error) {
var entry KeyValue
if err := s.db.First(&entry, "key = ?", k).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", kv.ErrKeyNotFound
}
return "", errors.Wrap(err, "query")
}
return string(entry.Value), nil
}

View file

@ -0,0 +1,149 @@
package tgstorage
import (
"context"
"database/sql"
"encoding/json"
"github.com/go-faster/errors"
"github.com/gotd/contrib/storage"
"github.com/tgdrive/teldrive/internal/cache"
"gorm.io/gorm"
)
var _ storage.PeerStorage = PeerStorage{}
type PeerStorage struct {
db *gorm.DB
prefix string
}
func NewPeerStorage(db *gorm.DB, prefix string) *PeerStorage {
return &PeerStorage{
db: db,
prefix: prefix,
}
}
type sqliteIterator struct {
rows *sql.Rows
value storage.Peer
}
func (p *sqliteIterator) Close() error {
return p.rows.Close()
}
func (p *sqliteIterator) Next(ctx context.Context) bool {
if !p.rows.Next() {
return false
}
var val []byte
if err := p.rows.Scan(&val); err != nil {
return false
}
if err := json.Unmarshal(val, &p.value); err != nil {
if errors.Is(err, storage.ErrPeerUnmarshalMustInvalidate) {
return p.Next(ctx)
}
return false
}
return true
}
func (p *sqliteIterator) Err() error {
return p.rows.Err()
}
func (p *sqliteIterator) Value() storage.Peer {
return p.value
}
func (s PeerStorage) Iterate(ctx context.Context) (storage.PeerIterator, error) {
rows, err := s.db.Model(&KeyValue{}).
Select("value").
Where("key LIKE ?", s.prefix+"%").
Rows()
if err != nil {
return nil, errors.Wrap(err, "query")
}
return &sqliteIterator{rows: rows}, nil
}
func (s PeerStorage) add(associated []string, value storage.Peer) error {
return s.db.Transaction(func(tx *gorm.DB) error {
data, err := json.Marshal(value)
if err != nil {
return errors.Wrap(err, "marshal")
}
if err := tx.Save(&KeyValue{
Key: cache.Key(s.prefix, storage.KeyFromPeer(value).String()),
Value: data,
}).Error; err != nil {
return errors.Wrap(err, "save peer")
}
for _, key := range associated {
if err := tx.Save(&KeyValue{
Key: cache.Key(s.prefix, key),
Value: data,
}).Error; err != nil {
return errors.Wrap(err, "save associated key")
}
}
return nil
})
}
func (s PeerStorage) Add(ctx context.Context, value storage.Peer) error {
return s.add(value.Keys(), value)
}
func (s PeerStorage) Find(ctx context.Context, key storage.PeerKey) (storage.Peer, error) {
var entry KeyValue
if err := s.db.First(&entry, "key = ?", cache.Key(s.prefix, key.String())).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return storage.Peer{}, storage.ErrPeerNotFound
}
return storage.Peer{}, errors.Wrap(err, "query")
}
var p storage.Peer
if err := json.Unmarshal([]byte(entry.Value), &p); err != nil {
if errors.Is(err, storage.ErrPeerUnmarshalMustInvalidate) {
return storage.Peer{}, storage.ErrPeerNotFound
}
return storage.Peer{}, errors.Wrap(err, "unmarshal")
}
return p, nil
}
func (s PeerStorage) Assign(ctx context.Context, key string, value storage.Peer) error {
return s.add(append(value.Keys(), key), value)
}
func (s PeerStorage) Resolve(ctx context.Context, key string) (storage.Peer, error) {
var entry KeyValue
if err := s.db.First(&entry, "key = ?", cache.Key(s.prefix, key)).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return storage.Peer{}, storage.ErrPeerNotFound
}
return storage.Peer{}, errors.Wrap(err, "query")
}
var p storage.Peer
if err := json.Unmarshal([]byte(entry.Value), &p); err != nil {
if errors.Is(err, storage.ErrPeerUnmarshalMustInvalidate) {
return storage.Peer{}, storage.ErrPeerNotFound
}
return storage.Peer{}, errors.Wrap(err, "unmarshal")
}
return p, nil
}

View file

@ -0,0 +1,23 @@
package tgstorage
import (
"github.com/gotd/td/session"
"gorm.io/gorm"
"github.com/gotd/contrib/auth/kv"
)
var _ session.Storage = SessionStorage{}
type SessionStorage struct {
kv.Session
}
func NewSessionStorage(db *gorm.DB, key string) SessionStorage {
s := &kvStorage{
db: db,
}
return SessionStorage{
Session: kv.NewSession(s, key),
}
}

View file

@ -7,7 +7,6 @@ import (
"github.com/go-faster/errors"
"github.com/gotd/td/telegram"
"github.com/ogen-go/ogen/ogenerrors"
"go.etcd.io/bbolt"
ht "github.com/ogen-go/ogen/http"
"github.com/tgdrive/teldrive/internal/api"
@ -22,7 +21,7 @@ type apiService struct {
db *gorm.DB
cnf *config.ServerCmdConfig
cache cache.Cacher
boltdb *bbolt.DB
tgdb *gorm.DB
worker *tgc.BotWorker
middlewares []telegram.Middleware
}
@ -55,9 +54,9 @@ func (a *apiService) NewError(ctx context.Context, err error) *api.ErrorStatusCo
func NewApiService(db *gorm.DB,
cnf *config.ServerCmdConfig,
cache cache.Cacher,
boltdb *bbolt.DB,
tgdb *gorm.DB,
worker *tgc.BotWorker) *apiService {
return &apiService{db: db, cnf: cnf, cache: cache, boltdb: boltdb, worker: worker,
return &apiService{db: db, cnf: cnf, cache: cache, tgdb: tgdb, worker: worker,
middlewares: tgc.NewMiddleware(&cnf.TG, tgc.WithFloodWait(), tgc.WithRateLimit())}
}

View file

@ -130,7 +130,7 @@ func (a *apiService) AuthLogout(ctx context.Context) (*api.AuthLogoutNoContent,
return err
})
a.db.Where("hash = ?", authUser.Hash).Delete(&models.Session{})
a.cache.Delete(cache.Key("sessions", authUser.Hash))
a.cache.Delete(cache.Key("sessions", authUser.Hash), cache.Key("users", "sessions", authUser.ID))
return &api.AuthLogoutNoContent{SetCookie: setCookie(authCookieName, "", -1)}, nil
}

View file

@ -716,7 +716,7 @@ func (e *extendedService) FilesStream(w http.ResponseWriter, r *http.Request, fi
token, _ = e.api.worker.Next(*file.ChannelId)
client, err = tgc.BotClient(ctx, e.api.boltdb, &e.api.cnf.TG, token, middlewares...)
client, err = tgc.BotClient(ctx, e.api.tgdb, &e.api.cnf.TG, token, middlewares...)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return

View file

@ -124,7 +124,7 @@ func (a *apiService) UploadsUpload(ctx context.Context, req *api.UploadsUploadRe
} else {
a.worker.Set(tokens, channelId)
token, index = a.worker.Next(channelId)
client, err = tgc.BotClient(ctx, a.boltdb, &a.cnf.TG, token)
client, err = tgc.BotClient(ctx, a.tgdb, &a.cnf.TG, token)
if err != nil {
return nil, err

View file

@ -18,11 +18,11 @@ import (
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/tgc"
"github.com/tgdrive/teldrive/internal/tgstorage"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
"golang.org/x/sync/errgroup"
tgbbolt "github.com/gotd/contrib/bbolt"
"github.com/gotd/contrib/storage"
"gorm.io/gorm/clause"
)
@ -52,12 +52,13 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
channels := make(map[int64]*api.Channel)
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
peerStorage := tgstorage.NewPeerStorage(a.tgdb, cache.Key("peers", userId))
iter, err := peerStorage.Iterate(ctx)
if err != nil {
return []api.Channel{}, nil
}
defer iter.Close()
for iter.Next(ctx) {
peer := iter.Value()
if peer.Channel != nil && peer.Channel.AdminRights.AddAdmins {
@ -82,7 +83,7 @@ func (a *apiService) UsersListChannels(ctx context.Context) ([]api.Channel, erro
func (a *apiService) UsersSyncChannels(ctx context.Context) error {
userId := auth.GetUser(ctx)
peerStorage := tgbbolt.NewPeerStorage(a.boltdb, []byte(fmt.Sprintf("peers:%d", userId)))
peerStorage := tgstorage.NewPeerStorage(a.tgdb, cache.Key("peers", userId))
collector := storage.CollectPeers(peerStorage)
client, err := tgc.AuthClient(ctx, &a.cnf.TG, auth.GetJWTUser(ctx).TgSession, a.middlewares...)
if err != nil {
@ -99,58 +100,62 @@ func (a *apiService) UsersSyncChannels(ctx context.Context) error {
func (a *apiService) UsersListSessions(ctx context.Context) ([]api.UserSession, error) {
userId := auth.GetUser(ctx)
return cache.Fetch(a.cache, cache.Key("users", "sessions", userId), 0, func() ([]api.UserSession, error) {
userSession := auth.GetJWTUser(ctx).TgSession
userSession := auth.GetJWTUser(ctx).TgSession
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...)
client, _ := tgc.AuthClient(ctx, &a.cnf.TG, userSession, a.middlewares...)
var (
auth *tg.AccountAuthorizations
err error
)
var (
auth *tg.AccountAuthorizations
err error
)
err = client.Run(ctx, func(ctx context.Context) error {
auth, err = client.API().AccountGetAuthorizations(ctx)
if err != nil {
return err
err = client.Run(ctx, func(ctx context.Context) error {
auth, err = client.API().AccountGetAuthorizations(ctx)
if err != nil {
return err
}
return nil
})
if err != nil && !tgerr.Is(err, "AUTH_KEY_UNREGISTERED") {
return nil, err
}
return nil
})
if err != nil && !tgerr.Is(err, "AUTH_KEY_UNREGISTERED") {
return nil, err
}
dbSessions := []models.Session{}
dbSessions := []models.Session{}
if err = a.db.Where("user_id = ?", userId).Order("created_at DESC").Find(&dbSessions).Error; err != nil {
return nil, err
}
if err = a.db.Where("user_id = ?", userId).Order("created_at DESC").Find(&dbSessions).Error; err != nil {
return nil, err
}
sessionsOut := []api.UserSession{}
sessionsOut := []api.UserSession{}
for _, session := range dbSessions {
for _, session := range dbSessions {
s := api.UserSession{Hash: session.Hash,
CreatedAt: session.CreatedAt.UTC(),
Current: session.Session == userSession}
s := api.UserSession{Hash: session.Hash,
CreatedAt: session.CreatedAt.UTC(),
Current: session.Session == userSession}
if auth != nil {
for _, a := range auth.Authorizations {
if session.SessionDate == a.DateCreated {
s.AppName = api.NewOptString(strings.Trim(strings.Replace(a.AppName, "Telegram", "", -1), " "))
s.Location = api.NewOptString(a.Country)
s.OfficialApp = api.NewOptBool(a.OfficialApp)
s.Valid = true
break
if auth != nil {
for _, a := range auth.Authorizations {
if session.SessionDate == a.DateCreated {
s.AppName = api.NewOptString(strings.Trim(strings.Replace(a.AppName, "Telegram", "", -1), " "))
s.Location = api.NewOptString(a.Country)
s.OfficialApp = api.NewOptBool(a.OfficialApp)
s.Valid = true
break
}
}
}
sessionsOut = append(sessionsOut, s)
}
sessionsOut = append(sessionsOut, s)
}
return sessionsOut, nil
})
return sessionsOut, nil
}
func (a *apiService) UsersProfileImage(ctx context.Context, params api.UsersProfileImageParams) (*api.UsersProfileImageOKHeaders, error) {
@ -234,6 +239,7 @@ func (a *apiService) UsersRemoveSession(ctx context.Context, params api.UsersRem
})
a.db.Where("user_id = ?", userId).Where("hash = ?", session.Hash).Delete(&models.Session{})
a.cache.Delete(cache.Key("users", "sessions", userId))
return nil
}
@ -247,13 +253,13 @@ func (a *apiService) UsersStats(ctx context.Context) (*api.UserConfig, error) {
channelId, err = getDefaultChannel(a.db, a.cache, userId)
if err != nil {
return nil, &apiError{err: err}
channelId = 0
}
tokens, err := getBotsToken(a.db, a.cache, userId, channelId)
if err != nil {
return nil, &apiError{err: err}
tokens = []string{}
}
return &api.UserConfig{Bots: tokens, ChannelId: channelId}, nil
}
@ -303,7 +309,7 @@ func (a *apiService) addBots(c context.Context, client *telegram.Client, userId
for _, token := range botsTokens {
g.Go(func() error {
info, err := tgc.GetBotInfo(c, a.boltdb, &a.cnf.TG, token)
info, err := tgc.GetBotInfo(c, a.tgdb, &a.cnf.TG, token)
if err != nil {
return err
}

View file

@ -15,6 +15,9 @@ vars:
sh: go env GOARCH
BINARY_EXTENSION: '{{if eq OS "windows"}}.exe{{end}}'
env:
CGO_ENABLED: 0
tasks:
default:
cmds: