diff --git a/api/router.go b/api/router.go index 99fa3f6..3179fe9 100644 --- a/api/router.go +++ b/api/router.go @@ -10,7 +10,7 @@ import ( "gorm.io/gorm" ) -func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config, db *gorm.DB, cache *cache.Cache) *gin.Engine { +func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config, db *gorm.DB, cache cache.Cacher) *gin.Engine { authmiddleware := middleware.Authmiddleware(cnf.JWT.Secret, db, cache) api := r.Group("/api") { diff --git a/cmd/run.go b/cmd/run.go index 70ca5bb..b91c8a5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -60,6 +60,14 @@ func NewRun() *cobra.Command { duration.DurationVar(runCmd.Flags(), &config.Server.WriteTimeout, "server-write-timeout", 1*time.Hour, "Server write timeout") runCmd.Flags().BoolVar(&config.CronJobs.Enable, "cronjobs-enable", true, "Run cron jobs") + duration.DurationVar(runCmd.Flags(), &config.CronJobs.CleanFilesInterval, "cronjobs-clean-files-interval", 1*time.Hour, "Clean files interval") + duration.DurationVar(runCmd.Flags(), &config.CronJobs.CleanUploadsInterval, "cronjobs-clean-uploads-interval", 12*time.Hour, "Clean uploads interval") + duration.DurationVar(runCmd.Flags(), &config.CronJobs.FolderSizeInterval, "cronjobs-folder-size-interval", 2*time.Hour, "Folder size update interval") + + runCmd.Flags().StringVar(&config.Cache.Type, "cache-type", "memory", "Cache type redis or memory") + runCmd.Flags().IntVar(&config.Cache.MaxSize, "cache-max-size", 10*1024*1024, "Max Cache max size (memory)") + runCmd.Flags().StringVar(&config.Cache.RedisAddr, "cache-redis-addr", "localhost:6379", "Redis address") + runCmd.Flags().StringVar(&config.Cache.RedisPass, "cache-redis-pass", "", "Redis password") runCmd.Flags().IntVarP(&config.Log.Level, "log-level", "", -1, "Logging level") runCmd.Flags().StringVar(&config.Log.File, "log-file", "", "Logging file path") @@ -121,7 +129,7 @@ func runApplication(conf *config.Config) { FilePath: conf.Log.File, }) - tgContext, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) defer func() { logging.DefaultLogger().Sync() @@ -130,17 +138,21 @@ func runApplication(conf *config.Config) { scheduler := gocron.NewScheduler(time.UTC) + cacher := cache.NewCache(ctx, conf) + app := fx.New( fx.Supply(conf), fx.Supply(scheduler), + fx.Provide(func() cache.Cacher { + return cacher + }), fx.Supply(logging.DefaultLogger().Desugar()), fx.NopLogger, fx.StopTimeout(conf.Server.GracefulShutdown+time.Second), fx.Provide( database.NewDatabase, - cache.DefaultCache, kv.NewBoltKV, - tgc.NewStreamWorker(tgContext), + tgc.NewStreamWorker(ctx), tgc.NewUploadWorker, services.NewAuthService, services.NewFileService, @@ -226,7 +238,7 @@ func modifyFlag(s string) string { return string(result) } -func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller, db *gorm.DB, cache *cache.Cache) *gin.Engine { +func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller, db *gorm.DB, cache cache.Cacher) *gin.Engine { gin.SetMode(gin.ReleaseMode) diff --git a/go.mod b/go.mod index d34f8c3..dd6b09d 100644 --- a/go.mod +++ b/go.mod @@ -17,10 +17,11 @@ require ( github.com/magiconair/properties v1.8.7 github.com/mitchellh/go-homedir v1.1.0 github.com/pkg/errors v0.9.1 + github.com/redis/go-redis/v9 v9.6.1 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 - github.com/vmihailenco/msgpack v4.0.4+incompatible + github.com/vmihailenco/msgpack/v5 v5.4.1 go.etcd.io/bbolt v1.3.10 go.uber.org/fx v1.22.1 go.uber.org/zap v1.27.0 @@ -38,9 +39,9 @@ require ( github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -57,9 +58,9 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect go.uber.org/dig v1.17.1 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - google.golang.org/appengine v1.6.8 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gorm.io/driver/mysql v1.5.6 // indirect diff --git a/go.sum b/go.sum index c90286e..0736905 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,10 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7Oputl github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/WinterYukky/gorm-extra-clause-plugin v0.2.1 h1:G0e4eFRrh3WdM1I3EKKidV2yF5J09uRIJlKYxt6zNR4= github.com/WinterYukky/gorm-extra-clause-plugin v0.2.1/go.mod h1:qAN5KRJJTCM49X2wUHZAVB3rfvO8A8L0ISd/uB1WM5s= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= @@ -25,6 +29,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/divyam234/cors v1.4.2 h1:moAxStmYpvG9/SkPz+Wld02iutgo3JcUvrez6Kit/D8= github.com/divyam234/cors v1.4.2/go.mod h1:JrxBJAqTU7jtPItodwf2mzxbbZm0Qq0NFkK8jo9UUDk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= @@ -77,11 +83,6 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -169,6 +170,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pressly/goose/v3 v3.21.1 h1:5SSAKKWej8LVVzNLuT6KIvP1eFDuPvxa+B6H0w78buQ= github.com/pressly/goose/v3 v3.21.1/go.mod h1:sqthmzV8PitchEkjecFJII//l43dLOCzfWh8pHEe+vE= +github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= +github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -217,9 +220,10 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= -github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= @@ -242,51 +246,23 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 9e47f49..6d4f2c3 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -43,7 +43,7 @@ func GetUser(c *gin.Context) (int64, string) { return userId, jwtUser.TgSession } -func VerifyUser(c *gin.Context, db *gorm.DB, cache *cache.Cache, secret string) (*types.JWTClaims, error) { +func VerifyUser(c *gin.Context, db *gorm.DB, cache cache.Cacher, secret string) (*types.JWTClaims, error) { var token string cookie, err := c.Request.Cookie("user-session") @@ -77,7 +77,7 @@ func VerifyUser(c *gin.Context, db *gorm.DB, cache *cache.Cache, secret string) return claims, nil } -func GetSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) { +func GetSessionByHash(db *gorm.DB, cache cache.Cacher, hash string) (*models.Session, error) { var session models.Session key := fmt.Sprintf("sessions:%s", hash) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 66c79bd..a774222 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -2,99 +2,106 @@ package cache import ( "context" - "sync" "time" "github.com/coocood/freecache" - "github.com/gin-gonic/gin" - "github.com/vmihailenco/msgpack" + "github.com/divyam234/teldrive/internal/config" + "github.com/redis/go-redis/v9" + "github.com/vmihailenco/msgpack/v5" ) -type Cache struct { - cache *freecache.Cache - mu sync.RWMutex +type Cacher interface { + Get(key string, value interface{}) error + Set(key string, value interface{}, expiration time.Duration) error + Delete(keys ...string) error } -func (c *Cache) Get(key string, value interface{}) error { - c.mu.RLock() - defer c.mu.RUnlock() - result, err := c.cache.Get([]byte(key)) +type MemoryCache struct { + cache *freecache.Cache + prefix string +} + +func NewCache(ctx context.Context, conf *config.Config) Cacher { + var cacher Cacher + switch conf.Cache.Type { + case "memory": + cacher = NewMemoryCache(conf.Cache.MaxSize) + case "redis": + cacher = NewRedisCache(ctx, redis.NewClient(&redis.Options{ + Addr: conf.Cache.RedisAddr, + Password: conf.Cache.RedisPass, + })) + } + return cacher +} + +func NewMemoryCache(size int) *MemoryCache { + return &MemoryCache{ + cache: freecache.NewCache(size), + prefix: "teldrive:", + } +} + +func (m *MemoryCache) Get(key string, value interface{}) error { + key = m.prefix + key + data, err := m.cache.Get([]byte(key)) if err != nil { return err } + return msgpack.Unmarshal(data, value) +} - err = msgpack.Unmarshal(result, value) - +func (m *MemoryCache) Set(key string, value interface{}, expiration time.Duration) error { + key = m.prefix + key + data, err := msgpack.Marshal(value) if err != nil { return err } + return m.cache.Set([]byte(key), data, int(expiration.Seconds())) +} + +func (m *MemoryCache) Delete(keys ...string) error { + for _, key := range keys { + m.cache.Del([]byte(m.prefix + key)) + } return nil } -func (c *Cache) Set(key string, value interface{}, expires time.Duration) error { - c.mu.Lock() - defer c.mu.Unlock() - bytes, err := msgpack.Marshal(value) +type RedisCache struct { + client *redis.Client + ctx context.Context + prefix string +} + +func NewRedisCache(ctx context.Context, client *redis.Client) *RedisCache { + return &RedisCache{ + client: client, + prefix: "teldrive:", + ctx: ctx, + } +} + +func (r *RedisCache) Get(key string, value interface{}) error { + key = r.prefix + key + data, err := r.client.Get(r.ctx, key).Bytes() if err != nil { return err } - return c.cache.Set([]byte(key), bytes, int(expires.Seconds())) + return msgpack.Unmarshal(data, value) } -func (c *Cache) Delete(key string) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.cache.Del([]byte(key)) - return nil -} - -var ( - defaultCache *Cache - defaultCacheOnce sync.Once -) - -type Config struct { - Size int -} - -var conf = &Config{ - Size: 5 * 1024 * 1024, -} - -func SetConfig(c *Config) { - conf = &Config{ - Size: c.Size, +func (r *RedisCache) Set(key string, value interface{}, expiration time.Duration) error { + key = r.prefix + key + data, err := msgpack.Marshal(value) + if err != nil { + return err } + return r.client.Set(r.ctx, key, data, expiration).Err() } -func DefaultCache() *Cache { - defaultCacheOnce.Do(func() { - defaultCache = &Cache{cache: freecache.NewCache(conf.Size)} - }) - return defaultCache -} - -type cacheKeyType string - -var contextKey = cacheKeyType("cache") - -func WithCache(ctx context.Context, cache *Cache) context.Context { - if gCtx, ok := ctx.(*gin.Context); ok { - ctx = gCtx.Request.Context() - } - return context.WithValue(ctx, contextKey, cache) -} - -func FromContext(ctx context.Context) *Cache { - if ctx == nil { - return DefaultCache() - } - if gCtx, ok := ctx.(*gin.Context); ok && gCtx != nil { - ctx = gCtx.Request.Context() - } - if cache, ok := ctx.Value(contextKey).(*Cache); ok { - return cache - } - return DefaultCache() +func (r *RedisCache) Delete(keys ...string) error { + for i := range keys { + keys[i] = r.prefix + keys[i] + } + return r.client.Del(r.ctx, keys...).Err() } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 9fa2e8f..38dea2b 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -1,7 +1,6 @@ package cache import ( - "context" "testing" "time" @@ -10,8 +9,6 @@ import ( ) func TestCache(t *testing.T) { - ctx := context.Background() - cache := FromContext(ctx) var value = schemas.FileIn{ Name: "file.jpeg", @@ -19,7 +16,9 @@ func TestCache(t *testing.T) { } var result schemas.FileIn - err := cache.Set("key", value, 1*time.Minute) + cache := NewMemoryCache(1 * 1024 * 1024) + + err := cache.Set("key", value, 1*time.Second) assert.NoError(t, err) err = cache.Get("key", &result) diff --git a/internal/config/config.go b/internal/config/config.go index 3e5cb78..6ba3a9a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,12 @@ type Config struct { DB DBConfig TG TGConfig CronJobs CronJobConfig + Cache struct { + Type string + MaxSize int + RedisAddr string + RedisPass string + } } type ServerConfig struct { @@ -22,7 +28,10 @@ type ServerConfig struct { } type CronJobConfig struct { - Enable bool + Enable bool + CleanFilesInterval time.Duration + CleanUploadsInterval time.Duration + FolderSizeInterval time.Duration } type TGConfig struct { diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 58348ff..7552117 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -33,14 +33,12 @@ func Cors() gin.HandlerFunc { return cors.New(cors.Config{ AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}, AllowHeaders: []string{"Authorization", "Content-Length", "Content-Type"}, - AllowOriginFunc: func(origin string) bool { - return true - }, - MaxAge: 12 * time.Hour, + AllowOrigins: []string{"*"}, + MaxAge: 12 * time.Hour, }) } -func Authmiddleware(secret string, db *gorm.DB, cache *cache.Cache) gin.HandlerFunc { +func Authmiddleware(secret string, db *gorm.DB, cache cache.Cacher) gin.HandlerFunc { return func(c *gin.Context) { user, err := auth.VerifyUser(c, db, cache, secret) if err != nil { diff --git a/internal/reader/decrypted_reader.go b/internal/reader/decrypted_reader.go index 1b43231..168fcfb 100644 --- a/internal/reader/decrypted_reader.go +++ b/internal/reader/decrypted_reader.go @@ -4,6 +4,7 @@ import ( "context" "io" + "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/internal/tgc" @@ -23,6 +24,7 @@ type decrpytedReader struct { client *tgc.Client fileId string concurrency int + cache cache.Cacher } func NewDecryptedReader( @@ -34,7 +36,8 @@ func NewDecryptedReader( config *config.TGConfig, concurrency int, client *tgc.Client, - worker *tgc.StreamWorker) (*decrpytedReader, error) { + worker *tgc.StreamWorker, + cache cache.Cacher) (*decrpytedReader, error) { r := &decrpytedReader{ ctx: ctx, @@ -47,6 +50,7 @@ func NewDecryptedReader( channelId: channelId, fileId: fileId, concurrency: concurrency, + cache: cache, } res, err := r.nextPart() @@ -111,7 +115,7 @@ func (r *decrpytedReader) nextPart() (io.ReadCloser, error) { } chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, - client: r.client, concurrency: r.concurrency} + client: r.client, concurrency: r.concurrency, cache: r.cache} if r.concurrency < 2 { return newTGReader(r.ctx, underlyingOffset, end, chunkSrc) diff --git a/internal/reader/reader.go b/internal/reader/reader.go index fb896ca..59e439e 100644 --- a/internal/reader/reader.go +++ b/internal/reader/reader.go @@ -4,6 +4,7 @@ import ( "context" "io" + "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/tgc" "github.com/divyam234/teldrive/pkg/types" @@ -49,6 +50,7 @@ type linearReader struct { client *tgc.Client fileId string concurrency int + cache cache.Cacher } func NewLinearReader(ctx context.Context, @@ -60,6 +62,7 @@ func NewLinearReader(ctx context.Context, concurrency int, client *tgc.Client, worker *tgc.StreamWorker, + cache cache.Cacher, ) (reader io.ReadCloser, err error) { r := &linearReader{ @@ -73,6 +76,7 @@ func NewLinearReader(ctx context.Context, channelId: channelId, fileId: fileId, concurrency: concurrency, + cache: cache, } r.reader, err = r.nextPart() @@ -116,7 +120,7 @@ func (r *linearReader) nextPart() (io.ReadCloser, error) { chunkSrc := &chunkSource{channelId: r.channelId, worker: r.worker, fileId: r.fileId, partId: r.parts[r.ranges[r.pos].PartNo].ID, - client: r.client, concurrency: r.concurrency} + client: r.client, concurrency: r.concurrency, cache: r.cache} if r.concurrency < 2 { return newTGReader(r.ctx, start, end, chunkSrc) } diff --git a/internal/reader/tg_multi_reader.go b/internal/reader/tg_multi_reader.go index 51f8f8c..aa59a61 100644 --- a/internal/reader/tg_multi_reader.go +++ b/internal/reader/tg_multi_reader.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" "github.com/divyam234/teldrive/internal/tgc" "github.com/gotd/td/tg" @@ -28,6 +29,7 @@ type chunkSource struct { partId int64 concurrency int client *tgc.Client + cache cache.Cacher } func (c *chunkSource) ChunkSize(start, end int64) int64 { @@ -52,7 +54,7 @@ func (c *chunkSource) Chunk(ctx context.Context, offset int64, limit int64) ([]b if c.concurrency > 0 { client, _, _ = c.worker.Next(c.channelId) } - location, err = tgc.GetLocation(ctx, client, c.fileId, c.channelId, c.partId) + location, err = tgc.GetLocation(ctx, client, c.cache, c.fileId, c.channelId, c.partId) if err != nil { return nil, err diff --git a/internal/tgc/helpers.go b/internal/tgc/helpers.go index 01e87bb..22f4754 100644 --- a/internal/tgc/helpers.go +++ b/internal/tgc/helpers.go @@ -8,6 +8,7 @@ import ( "math" "runtime" "sync" + "time" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/config" @@ -198,9 +199,7 @@ func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token st return &types.BotInfo{Id: user.ID, UserName: user.Username, Token: token}, nil } -func GetLocation(ctx context.Context, client *Client, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) { - - cache := cache.FromContext(ctx) +func GetLocation(ctx context.Context, client *Client, cache cache.Cacher, fileId string, channelId int64, partId int64) (location *tg.InputDocumentFileLocation, err error) { key := fmt.Sprintf("files:location:%s:%s:%d", client.UserId, fileId, partId) @@ -235,7 +234,7 @@ func GetLocation(ctx context.Context, client *Client, fileId string, channelId i media := item.Media.(*tg.MessageMediaDocument) document := media.Document.(*tg.Document) location = document.AsInputDocumentFileLocation() - cache.Set(key, location, 1800) + cache.Set(key, location, 30*time.Minute) } } return location, nil diff --git a/pkg/controller/file.go b/pkg/controller/file.go index eed0b96..aacdd71 100644 --- a/pkg/controller/file.go +++ b/pkg/controller/file.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/divyam234/teldrive/internal/auth" - "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/pkg/httputil" "github.com/divyam234/teldrive/pkg/schemas" "github.com/gin-gonic/gin" @@ -39,7 +38,7 @@ func (fc *Controller) UpdateFile(c *gin.Context) { httputil.NewError(c, http.StatusBadRequest, err) return } - res, err := fc.FileService.UpdateFile(c.Param("fileID"), userId, &fileUpdate, cache.FromContext(c)) + res, err := fc.FileService.UpdateFile(c.Param("fileID"), userId, &fileUpdate) if err != nil { httputil.NewError(c, err.Code, err.Error) return @@ -150,13 +149,15 @@ func (fc *Controller) DeleteFiles(c *gin.Context) { func (fc *Controller) UpdateParts(c *gin.Context) { + userId, _ := auth.GetUser(c) + var payload schemas.PartUpdate if err := c.ShouldBindJSON(&payload); err != nil { httputil.NewError(c, http.StatusBadRequest, err) return } - res, err := fc.FileService.UpdateParts(c, c.Param("fileID"), &payload) + res, err := fc.FileService.UpdateParts(c, c.Param("fileID"), userId, &payload) if err != nil { httputil.NewError(c, err.Code, err.Error) return diff --git a/pkg/cron/cron.go b/pkg/cron/cron.go index 795a578..4352778 100644 --- a/pkg/cron/cron.go +++ b/pkg/cron/cron.go @@ -49,11 +49,11 @@ func StartCronJobs(scheduler *gocron.Scheduler, db *gorm.DB, cnf *config.Config) cron := CronService{db: db, cnf: cnf, logger: logging.DefaultLogger()} - scheduler.Every(1).Hour().Do(cron.CleanFiles, ctx) + scheduler.Every(cnf.CronJobs.CleanFilesInterval).Do(cron.CleanFiles, ctx) - scheduler.Every(2).Hour().Do(cron.UpdateFolderSize) + scheduler.Every(cnf.CronJobs.FolderSizeInterval).Do(cron.UpdateFolderSize) - scheduler.Every(12).Hour().Do(cron.CleanUploads, ctx) + scheduler.Every(cnf.CronJobs.CleanUploadsInterval).Do(cron.CleanUploads, ctx) scheduler.StartAsync() } diff --git a/pkg/mapper/mapper.go b/pkg/mapper/mapper.go index 33485c3..8638368 100644 --- a/pkg/mapper/mapper.go +++ b/pkg/mapper/mapper.go @@ -36,7 +36,6 @@ func ToFileOutFull(file models.File) *schemas.FileOutFull { FileOut: ToFileOut(file), Parts: file.Parts, ChannelID: channelId, - Encrypted: file.Encrypted, } } diff --git a/pkg/schemas/file.go b/pkg/schemas/file.go index 4bb2256..aa175e5 100644 --- a/pkg/schemas/file.go +++ b/pkg/schemas/file.go @@ -57,7 +57,6 @@ type FileOutFull struct { *FileOut Parts []Part `json:"parts,omitempty"` ChannelID int64 `json:"channelId,omitempty"` - Encrypted bool `json:"encrypted"` } type FileUpdate struct { diff --git a/pkg/services/auth.go b/pkg/services/auth.go index cf990c8..39c6c2f 100644 --- a/pkg/services/auth.go +++ b/pkg/services/auth.go @@ -39,10 +39,10 @@ import ( type AuthService struct { db *gorm.DB cnf *config.Config - cache *cache.Cache + cache cache.Cacher } -func NewAuthService(db *gorm.DB, cnf *config.Config, cache *cache.Cache) *AuthService { +func NewAuthService(db *gorm.DB, cnf *config.Config, cache cache.Cacher) *AuthService { return &AuthService{db: db, cnf: cnf, cache: cache} } @@ -188,8 +188,7 @@ func (as *AuthService) Logout(c *gin.Context) (*schemas.Message, *types.AppError }) setSessionCookie(c, "", -1) as.db.Where("session = ?", jwtUser.TgSession).Delete(&models.Session{}) - cache := cache.FromContext(c) - cache.Delete(fmt.Sprintf("sessions:%s", jwtUser.Hash)) + as.cache.Delete(fmt.Sprintf("sessions:%s", jwtUser.Hash)) return &schemas.Message{Message: "logout success"}, nil } diff --git a/pkg/services/common.go b/pkg/services/common.go index 7ca205e..5389b9d 100644 --- a/pkg/services/common.go +++ b/pkg/services/common.go @@ -3,6 +3,7 @@ package services import ( "context" "fmt" + "time" "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/crypt" @@ -15,8 +16,8 @@ import ( "gorm.io/gorm" ) -func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, userID string) ([]types.Part, error) { - cache := cache.FromContext(ctx) +func getParts(ctx context.Context, client *tg.Client, cache cache.Cacher, file *schemas.FileOutFull, userID string) ([]types.Part, error) { + parts := []types.Part{} key := fmt.Sprintf("files:messages:%s:%s", file.Id, userID) @@ -52,12 +53,12 @@ func getParts(ctx context.Context, client *tg.Client, file *schemas.FileOutFull, } parts = append(parts, part) } - cache.Set(key, &parts, 3600) + cache.Set(key, &parts, 60*time.Minute) return parts, nil } -func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, error) { - cache := cache.FromContext(ctx) +func getDefaultChannel(db *gorm.DB, cache cache.Cacher, userID int64) (int64, error) { + var channelId int64 key := fmt.Sprintf("users:channel:%d", userID) @@ -83,8 +84,7 @@ func getDefaultChannel(ctx context.Context, db *gorm.DB, userID int64) (int64, e return channelId, nil } -func getBotsToken(ctx context.Context, db *gorm.DB, userID, channelId int64) ([]string, error) { - cache := cache.FromContext(ctx) +func getBotsToken(db *gorm.DB, cache cache.Cacher, userID, channelId int64) ([]string, error) { var bots []string key := fmt.Sprintf("users:bots:%d:%d", userID, channelId) diff --git a/pkg/services/file.go b/pkg/services/file.go index 2fc0a91..cfd6661 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -78,10 +78,10 @@ type FileService struct { db *gorm.DB cnf *config.Config worker *tgc.StreamWorker - cache *cache.Cache + cache cache.Cacher } -func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache *cache.Cache) *FileService { +func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache cache.Cacher) *FileService { return &FileService{db: db, cnf: cnf, worker: worker, cache: cache} } @@ -115,7 +115,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas. channelId := fileIn.ChannelID if fileIn.ChannelID == 0 { var err error - channelId, err = getDefaultChannel(c, fs.db, userId) + channelId, err = getDefaultChannel(fs.db, fs.cache, userId) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusNotFound} } @@ -145,7 +145,7 @@ func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas. return res, nil } -func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileUpdate, cache *cache.Cache) (*schemas.FileOut, *types.AppError) { +func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileUpdate) (*schemas.FileOut, *types.AppError) { var ( files []models.File chain *gorm.DB @@ -175,14 +175,7 @@ func (fs *FileService) UpdateFile(id string, userId int64, update *schemas.FileU return nil, &types.AppError{Error: database.ErrNotFound, Code: http.StatusNotFound} } - cache.Delete(fmt.Sprintf("files:%s", id)) - - if len(update.Parts) > 0 { - cache.Delete(fmt.Sprintf("files:messages:%s:%d", id, userId)) - for _, part := range files[0].Parts { - cache.Delete(fmt.Sprintf("files:location:%d:%s:%d", userId, id, part.ID)) - } - } + fs.cache.Delete(fmt.Sprintf("files:%s", id)) return mapper.ToFileOut(files[0]), nil @@ -403,7 +396,7 @@ func (fs *FileService) DeleteFiles(userId int64, payload *schemas.DeleteOperatio return &schemas.Message{Message: "files deleted"}, nil } -func (fs *FileService) UpdateParts(c *gin.Context, id string, payload *schemas.PartUpdate) (*schemas.Message, *types.AppError) { +func (fs *FileService) UpdateParts(c *gin.Context, id string, userId int64, payload *schemas.PartUpdate) (*schemas.Message, *types.AppError) { var file models.File @@ -447,6 +440,13 @@ func (fs *FileService) UpdateParts(c *gin.Context, id string, payload *schemas.P } client, _ := tgc.AuthClient(c, &fs.cnf.TG, session) tgc.DeleteMessages(c, client, *file.ChannelID, ids) + keys := []string{fmt.Sprintf("files:%s", id), fmt.Sprintf("files:messages:%s:%d", id, userId)} + for _, part := range file.Parts { + keys = append(keys, fmt.Sprintf("files:location:%d:%s:%d", userId, id, part.ID)) + + } + fs.cache.Delete(keys...) + } return &schemas.Message{Message: "file updated"}, nil @@ -497,7 +497,7 @@ func (fs *FileService) CopyFile(c *gin.Context) (*schemas.FileOut, *types.AppErr newIds := []schemas.Part{} - channelId, err := getDefaultChannel(c, fs.db, userId) + channelId, err := getDefaultChannel(fs.db, fs.cache, userId) if err != nil { return nil, &types.AppError{Error: err} } @@ -706,7 +706,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { c.Header("Content-Disposition", mime.FormatMediaType(disposition, map[string]string{"filename": file.Name})) - tokens, err := getBotsToken(c, fs.db, session.UserId, file.ChannelID) + tokens, err := getBotsToken(fs.db, fs.cache, session.UserId, file.ChannelID) logger := logging.FromContext(c) if err != nil { @@ -753,7 +753,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { } if r.Method != "HEAD" { - parts, err := getParts(c, client.Tg.API(), file, channelUser) + parts, err := getParts(c, client.Tg.API(), fs.cache, file, channelUser) if err != nil { logger.Error(ErrorStreamAbandoned, err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -764,9 +764,9 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) { multiThreads = 0 } if file.Encrypted { - lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker) + lr, err = reader.NewDecryptedReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache) } else { - lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker) + lr, err = reader.NewLinearReader(c, file.Id, parts, start, end, file.ChannelID, &fs.cnf.TG, multiThreads, client, fs.worker, fs.cache) } if err != nil { diff --git a/pkg/services/file_test.go b/pkg/services/file_test.go index 12f4aba..cf4f645 100644 --- a/pkg/services/file_test.go +++ b/pkg/services/file_test.go @@ -81,7 +81,7 @@ func (s *FileServiceSuite) Test_Update() { Name: "file3.jpeg", Type: "file", } - r, err := s.srv.UpdateFile(res.Id, 123456, data, nil) + r, err := s.srv.UpdateFile(res.Id, 123456, data) s.NoError(err.Error) s.Equal(r.Name, data.Name) } diff --git a/pkg/services/upload.go b/pkg/services/upload.go index 1eb1af0..c8d3871 100644 --- a/pkg/services/upload.go +++ b/pkg/services/upload.go @@ -14,6 +14,7 @@ import ( "time" "github.com/divyam234/teldrive/internal/auth" + "github.com/divyam234/teldrive/internal/cache" "github.com/divyam234/teldrive/internal/crypt" "github.com/divyam234/teldrive/internal/kv" "github.com/divyam234/teldrive/internal/logging" @@ -41,10 +42,11 @@ type UploadService struct { worker *tgc.UploadWorker cnf *config.TGConfig kv kv.KV + cache cache.Cacher } -func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV) *UploadService { - return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv} +func NewUploadService(db *gorm.DB, cnf *config.Config, worker *tgc.UploadWorker, kv kv.KV, cache cache.Cacher) *UploadService { + return &UploadService{db: db, worker: worker, cnf: &cnf.TG, kv: kv, cache: cache} } func (us *UploadService) GetUploadFileById(c *gin.Context) (*schemas.UploadOut, *types.AppError) { @@ -128,7 +130,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty defer fileStream.Close() if uploadQuery.ChannelID == 0 { - channelId, err = getDefaultChannel(c, us.db, userId) + channelId, err = getDefaultChannel(us.db, us.cache, userId) if err != nil { return nil, &types.AppError{Error: err} } @@ -136,7 +138,7 @@ func (us *UploadService) UploadFile(c *gin.Context) (*schemas.UploadPartOut, *ty channelId = uploadQuery.ChannelID } - tokens, err := getBotsToken(c, us.db, userId, channelId) + tokens, err := getBotsToken(us.db, us.cache, userId, channelId) if err != nil { return nil, &types.AppError{Error: err} diff --git a/pkg/services/upload_test.go b/pkg/services/upload_test.go index 23df76b..6aebaca 100644 --- a/pkg/services/upload_test.go +++ b/pkg/services/upload_test.go @@ -18,7 +18,7 @@ type UploadServiceSuite struct { func (s *UploadServiceSuite) SetupSuite() { s.db = database.NewTestDatabase(s.T(), false) - s.srv = NewUploadService(s.db, nil, nil, nil) + s.srv = NewUploadService(s.db, nil, nil, nil, nil) } func (s *UploadServiceSuite) SetupTest() { diff --git a/pkg/services/user.go b/pkg/services/user.go index 9ee04c8..c271082 100644 --- a/pkg/services/user.go +++ b/pkg/services/user.go @@ -32,13 +32,14 @@ import ( ) type UserService struct { - db *gorm.DB - cnf *config.Config - kv kv.KV + db *gorm.DB + cnf *config.Config + kv kv.KV + cache cache.Cacher } -func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV) *UserService { - return &UserService{db: db, cnf: cnf, kv: kv} +func NewUserService(db *gorm.DB, cnf *config.Config, kv kv.KV, cache cache.Cacher) *UserService { + return &UserService{db: db, cnf: cnf, kv: kv, cache: cache} } func (us *UserService) GetProfilePhoto(c *gin.Context) { _, session := auth.GetUser(c) @@ -89,9 +90,9 @@ func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.A err error ) - channelId, _ = getDefaultChannel(c, us.db, userID) + channelId, _ = getDefaultChannel(us.db, us.cache, userID) - tokens, err := getBotsToken(c, us.db, userID, channelId) + tokens, err := getBotsToken(us.db, us.cache, userID, channelId) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} @@ -101,8 +102,6 @@ func (us *UserService) GetStats(c *gin.Context) (*schemas.AccountStats, *types.A func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.AppError) { - cache := cache.FromContext(c) - userId, _ := auth.GetUser(c) var payload schemas.Channel @@ -125,7 +124,7 @@ func (us *UserService) UpdateChannel(c *gin.Context) (*schemas.Message, *types.A Where("user_id = ?", userId).Update("selected", false) key := fmt.Sprintf("users:channel:%d", userId) - cache.Set(key, payload.ChannelID, 0) + us.cache.Set(key, payload.ChannelID, 0) return &schemas.Message{Message: "channel updated"}, nil } @@ -258,7 +257,7 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro return &schemas.Message{Message: "no bots to add"}, nil } - channelId, err := getDefaultChannel(c, us.db, userId) + channelId, err := getDefaultChannel(us.db, us.cache, userId) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} @@ -270,11 +269,9 @@ func (us *UserService) AddBots(c *gin.Context) (*schemas.Message, *types.AppErro func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppError) { - cache := cache.FromContext(c) - userID, _ := auth.GetUser(c) - channelId, err := getDefaultChannel(c, us.db, userID) + channelId, err := getDefaultChannel(us.db, us.cache, userID) if err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} @@ -285,7 +282,7 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError} } - cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId)) + us.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userID, channelId)) return &schemas.Message{Message: "bots deleted"}, nil @@ -293,8 +290,6 @@ func (us *UserService) RemoveBots(c *gin.Context) (*schemas.Message, *types.AppE func (us *UserService) addBots(c context.Context, client *telegram.Client, userId int64, channelId int64, botsTokens []string) (*schemas.Message, *types.AppError) { - cache := cache.FromContext(c) - botInfoMap := make(map[string]*types.BotInfo) err := tgc.RunWithAuth(c, client, "", func(ctx context.Context) error { @@ -379,7 +374,7 @@ func (us *UserService) addBots(c context.Context, client *telegram.Client, userI }) } - cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId)) + us.cache.Delete(fmt.Sprintf("users:bots:%d:%d", userId, channelId)) if err := us.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&payload).Error; err != nil { return nil, &types.AppError{Error: err, Code: http.StatusInternalServerError}