mirror of
https://github.com/tgdrive/teldrive.git
synced 2024-09-20 08:15:55 +08:00
refactor: Switch to golang-jwt for auth
This commit is contained in:
parent
e7865a7637
commit
20d39cd5b9
|
@ -1,15 +1,17 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/middleware"
|
||||
"github.com/divyam234/teldrive/pkg/controller"
|
||||
"github.com/divyam234/teldrive/ui"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gin.Engine {
|
||||
authmiddleware := middleware.Authmiddleware(cnf.JWT.Secret)
|
||||
func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config, db *gorm.DB, cache *cache.Cache) *gin.Engine {
|
||||
authmiddleware := middleware.Authmiddleware(cnf.JWT.Secret, db, cache)
|
||||
api := r.Group("/api")
|
||||
{
|
||||
auth := api.Group("/auth")
|
||||
|
|
15
cmd/run.go
15
cmd/run.go
|
@ -12,6 +12,7 @@ import (
|
|||
"unicode"
|
||||
|
||||
"github.com/divyam234/teldrive/api"
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/internal/config"
|
||||
"github.com/divyam234/teldrive/internal/database"
|
||||
"github.com/divyam234/teldrive/internal/duration"
|
||||
|
@ -34,6 +35,7 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func NewRun() *cobra.Command {
|
||||
|
@ -134,12 +136,9 @@ func runApplication(conf *config.Config) {
|
|||
fx.Supply(logging.DefaultLogger().Desugar()),
|
||||
fx.NopLogger,
|
||||
fx.StopTimeout(conf.Server.GracefulShutdown+time.Second),
|
||||
fx.Invoke(
|
||||
initApp,
|
||||
cron.StartCronJobs,
|
||||
),
|
||||
fx.Provide(
|
||||
database.NewDatabase,
|
||||
cache.DefaultCache,
|
||||
kv.NewBoltKV,
|
||||
tgc.NewStreamWorker(tgContext),
|
||||
tgc.NewUploadWorker,
|
||||
|
@ -149,6 +148,10 @@ func runApplication(conf *config.Config) {
|
|||
services.NewUserService,
|
||||
controller.NewController,
|
||||
),
|
||||
fx.Invoke(
|
||||
initApp,
|
||||
cron.StartCronJobs,
|
||||
),
|
||||
)
|
||||
|
||||
app.Run()
|
||||
|
@ -223,7 +226,7 @@ func modifyFlag(s string) string {
|
|||
return string(result)
|
||||
}
|
||||
|
||||
func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller) *gin.Engine {
|
||||
func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller, db *gorm.DB, cache *cache.Cache) *gin.Engine {
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
|
@ -258,7 +261,7 @@ func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller) *gin
|
|||
c.Next()
|
||||
})
|
||||
|
||||
r = api.InitRouter(r, c, cfg)
|
||||
r = api.InitRouter(r, c, cfg, db, cache)
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: r,
|
||||
|
|
2
go.mod
2
go.mod
|
@ -10,7 +10,7 @@ require (
|
|||
github.com/gin-contrib/zap v1.1.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-co-op/gocron v1.37.0
|
||||
github.com/go-jose/go-jose/v3 v3.0.3
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/gotd/contrib v0.20.0
|
||||
github.com/gotd/td v0.105.0
|
||||
github.com/iyear/connectproxy v0.1.1
|
||||
|
|
19
go.sum
19
go.sum
|
@ -58,8 +58,6 @@ github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb
|
|||
github.com/go-faster/xor v0.3.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
|
||||
github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38=
|
||||
github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
|
||||
github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7G7k=
|
||||
github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
|
@ -73,6 +71,8 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
|
|||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
|
@ -82,7 +82,6 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
|
|||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
|
@ -245,24 +244,19 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
|||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
@ -272,22 +266,14 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
|
@ -295,7 +281,6 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
|||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
|
|
96
internal/auth/auth.go
Normal file
96
internal/auth/auth.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/divyam234/teldrive/pkg/models"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Encode(secret string, claims *types.JWTClaims) (string, error) {
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func Decode(secret string, token string) (*types.JWTClaims, error) {
|
||||
claims := &types.JWTClaims{}
|
||||
|
||||
tkn, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !tkn.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
return claims, err
|
||||
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) (int64, string) {
|
||||
val, _ := c.Get("jwtUser")
|
||||
jwtUser := val.(*types.JWTClaims)
|
||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||
return userId, jwtUser.TgSession
|
||||
}
|
||||
|
||||
func VerifyUser(c *gin.Context, db *gorm.DB, cache *cache.Cache, secret string) (*types.JWTClaims, error) {
|
||||
var token string
|
||||
cookie, err := c.Request.Cookie("user-session")
|
||||
|
||||
if err != nil {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
bearerToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(bearerToken) != 2 {
|
||||
return nil, fmt.Errorf("missing auth token")
|
||||
}
|
||||
token = bearerToken[1]
|
||||
} else {
|
||||
token = cookie.Value
|
||||
}
|
||||
|
||||
claims, err := Decode(secret, token)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var session *models.Session
|
||||
|
||||
session, err = GetSessionByHash(db, cache, claims.Hash)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid session")
|
||||
}
|
||||
|
||||
claims.TgSession = session.Session
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func GetSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) {
|
||||
var session models.Session
|
||||
|
||||
key := fmt.Sprintf("sessions:%s", hash)
|
||||
|
||||
err := cache.Get(key, &session)
|
||||
|
||||
if err != nil {
|
||||
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache.Set(key, &session, 0)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
|
||||
}
|
|
@ -1,104 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
)
|
||||
|
||||
func Encode(secret string, payload *types.JWTClaims) (string, error) {
|
||||
|
||||
rcpt := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: secret,
|
||||
}
|
||||
|
||||
enc, err := jose.NewEncrypter(jose.A128CBC_HS256, rcpt, nil)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jwt, _ := json.Marshal(payload)
|
||||
|
||||
jweObject, err := enc.Encrypt(jwt)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jweToken, err := jweObject.CompactSerialize()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return jweToken, nil
|
||||
}
|
||||
|
||||
func Decode(secret string, token string) (*types.JWTClaims, error) {
|
||||
jwe, err := jose.ParseEncrypted(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decryptedData, err := jwe.Decrypt(secret)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jwtToken := &types.JWTClaims{}
|
||||
|
||||
err = json.Unmarshal(decryptedData, jwtToken)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jwtToken, nil
|
||||
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) (int64, string) {
|
||||
val, _ := c.Get("jwtUser")
|
||||
jwtUser := val.(*types.JWTClaims)
|
||||
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
|
||||
return userId, jwtUser.TgSession
|
||||
}
|
||||
|
||||
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
|
||||
var token string
|
||||
cookie, err := c.Request.Cookie("user-session")
|
||||
|
||||
if err != nil {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
bearerToken := strings.Split(authHeader, "Bearer ")
|
||||
if len(bearerToken) != 2 {
|
||||
return nil, fmt.Errorf("missing auth token")
|
||||
}
|
||||
token = bearerToken[1]
|
||||
} else {
|
||||
token = cookie.Value
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
jwePayload, err := Decode(secret, token)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
|
||||
return nil, fmt.Errorf("token expired")
|
||||
|
||||
}
|
||||
return jwePayload, nil
|
||||
}
|
|
@ -7,7 +7,9 @@ import (
|
|||
|
||||
"github.com/divyam234/cors"
|
||||
"github.com/divyam234/teldrive/internal/auth"
|
||||
"github.com/divyam234/teldrive/internal/cache"
|
||||
"github.com/gin-contrib/secure"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
@ -38,9 +40,9 @@ func Cors() gin.HandlerFunc {
|
|||
})
|
||||
}
|
||||
|
||||
func Authmiddleware(secret string) gin.HandlerFunc {
|
||||
func Authmiddleware(secret string, db *gorm.DB, cache *cache.Cache) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user, err := auth.VerifyUser(c, secret)
|
||||
user, err := auth.VerifyUser(c, db, cache, secret)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
|
|
@ -20,6 +20,7 @@ var internalErrors = []string{
|
|||
"memory limit exit",
|
||||
"connection dead",
|
||||
"engine was closed",
|
||||
"STORAGE_CHOOSE_VOLUME_FAILED",
|
||||
}
|
||||
|
||||
type retry struct {
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
"github.com/divyam234/teldrive/pkg/schemas"
|
||||
"github.com/divyam234/teldrive/pkg/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gotd/td/session"
|
||||
tgauth "github.com/gotd/td/telegram/auth"
|
||||
|
@ -36,12 +36,13 @@ import (
|
|||
)
|
||||
|
||||
type AuthService struct {
|
||||
db *gorm.DB
|
||||
cnf *config.Config
|
||||
db *gorm.DB
|
||||
cnf *config.Config
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewAuthService(db *gorm.DB, cnf *config.Config) *AuthService {
|
||||
return &AuthService{db: db, cnf: cnf}
|
||||
func NewAuthService(db *gorm.DB, cnf *config.Config, cache *cache.Cache) *AuthService {
|
||||
return &AuthService{db: db, cnf: cnf, cache: cache}
|
||||
|
||||
}
|
||||
|
||||
|
@ -54,16 +55,16 @@ func (as *AuthService) LogIn(c *gin.Context, session *schemas.TgSession) (*schem
|
|||
|
||||
now := time.Now().UTC()
|
||||
|
||||
jwtClaims := &types.JWTClaims{Claims: jwt.Claims{
|
||||
Subject: strconv.FormatInt(session.UserID, 10),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
Expiry: jwt.NewNumericDate(now.Add(as.cnf.JWT.SessionTime)),
|
||||
}, TgSession: session.Sesssion,
|
||||
jwtClaims := &types.JWTClaims{
|
||||
Name: session.Name,
|
||||
UserName: session.UserName,
|
||||
Bot: session.Bot,
|
||||
IsPremium: session.IsPremium,
|
||||
}
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: strconv.FormatInt(session.UserID, 10),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(as.cnf.JWT.SessionTime)),
|
||||
}}
|
||||
|
||||
tokenhash := md5.Sum([]byte(session.Sesssion))
|
||||
hexToken := hex.EncodeToString(tokenhash[:])
|
||||
|
@ -144,43 +145,31 @@ func (as *AuthService) LogIn(c *gin.Context, session *schemas.TgSession) (*schem
|
|||
|
||||
func (as *AuthService) GetSession(c *gin.Context) *schemas.Session {
|
||||
|
||||
cookie, err := c.Request.Cookie("user-session")
|
||||
claims, err := auth.VerifyUser(c, as.db, as.cache, as.cnf.JWT.Secret)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
jwePayload, err := auth.Decode(as.cnf.JWT.Secret, cookie.Value)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cache := cache.FromContext(c)
|
||||
|
||||
_, err = getSessionByHash(as.db, cache, jwePayload.Hash)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
claims.TgSession = ""
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
newExpires := now.Add(as.cnf.JWT.SessionTime)
|
||||
|
||||
userId, _ := strconv.ParseInt(jwePayload.Subject, 10, 64)
|
||||
userId, _ := strconv.ParseInt(claims.Subject, 10, 64)
|
||||
|
||||
session := &schemas.Session{Name: jwePayload.Name,
|
||||
UserName: jwePayload.UserName,
|
||||
session := &schemas.Session{Name: claims.Name,
|
||||
UserName: claims.UserName,
|
||||
UserId: userId,
|
||||
Hash: jwePayload.Hash,
|
||||
Hash: claims.Hash,
|
||||
Expires: newExpires.Format(time.RFC3339)}
|
||||
|
||||
jwePayload.IssuedAt = jwt.NewNumericDate(now)
|
||||
claims.IssuedAt = jwt.NewNumericDate(now)
|
||||
|
||||
jwePayload.Expiry = jwt.NewNumericDate(newExpires)
|
||||
claims.ExpiresAt = jwt.NewNumericDate(newExpires)
|
||||
|
||||
jweToken, err := auth.Encode(as.cnf.JWT.Secret, jwePayload)
|
||||
jweToken, err := auth.Encode(as.cnf.JWT.Secret, claims)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
|
|
|
@ -104,21 +104,3 @@ func getBotsToken(ctx context.Context, db *gorm.DB, userID, channelId int64) ([]
|
|||
return bots, nil
|
||||
|
||||
}
|
||||
|
||||
func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) {
|
||||
var session models.Session
|
||||
|
||||
key := fmt.Sprintf("sessions:%s", hash)
|
||||
|
||||
err := cache.Get(key, &session)
|
||||
|
||||
if err != nil {
|
||||
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache.Set(key, &session, 0)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
|
||||
}
|
||||
|
|
|
@ -78,10 +78,11 @@ type FileService struct {
|
|||
db *gorm.DB
|
||||
cnf *config.Config
|
||||
worker *tgc.StreamWorker
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
|
||||
return &FileService{db: db, cnf: cnf, worker: worker}
|
||||
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache *cache.Cache) *FileService {
|
||||
return &FileService{db: db, cnf: cnf, worker: worker, cache: cache}
|
||||
}
|
||||
|
||||
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
|
||||
|
@ -566,8 +567,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
|
||||
authHash := c.Query("hash")
|
||||
|
||||
cache := cache.FromContext(c)
|
||||
|
||||
var (
|
||||
session *models.Session
|
||||
err error
|
||||
|
@ -576,7 +575,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
)
|
||||
|
||||
if authHash == "" {
|
||||
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
|
||||
user, err = auth.VerifyUser(c, fs.db, fs.cache, fs.cnf.JWT.Secret)
|
||||
if err != nil {
|
||||
http.Error(w, "missing session or authash", http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -584,7 +583,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
|
||||
session = &models.Session{UserId: userId, Session: user.TgSession}
|
||||
} else {
|
||||
session, err = getSessionByHash(fs.db, cache, authHash)
|
||||
session, err = auth.GetSessionByHash(fs.db, fs.cache, authHash)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid hash", http.StatusBadRequest)
|
||||
return
|
||||
|
@ -595,7 +594,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
|
||||
key := fmt.Sprintf("files:%s", fileID)
|
||||
|
||||
err = cache.Get(key, file)
|
||||
err = fs.cache.Get(key, file)
|
||||
|
||||
if err != nil {
|
||||
file, appErr = fs.GetFileByID(fileID)
|
||||
|
@ -603,7 +602,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
|
|||
http.Error(w, appErr.Error.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
cache.Set(key, file, 0)
|
||||
fs.cache.Set(key, file, 0)
|
||||
}
|
||||
|
||||
c.Header("Accept-Ranges", "bytes")
|
||||
|
|
|
@ -21,7 +21,7 @@ type FileServiceSuite struct {
|
|||
|
||||
func (s *FileServiceSuite) SetupSuite() {
|
||||
s.db = database.NewTestDatabase(s.T(), false)
|
||||
s.srv = NewFileService(s.db, nil, nil)
|
||||
s.srv = NewFileService(s.db, nil, nil, nil)
|
||||
}
|
||||
|
||||
func (s *FileServiceSuite) SetupTest() {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gotd/td/session"
|
||||
)
|
||||
|
||||
|
@ -18,13 +18,13 @@ type Part struct {
|
|||
}
|
||||
|
||||
type JWTClaims struct {
|
||||
jwt.Claims
|
||||
TgSession string `json:"tgSession"`
|
||||
jwt.RegisteredClaims
|
||||
Name string `json:"name"`
|
||||
UserName string `json:"userName"`
|
||||
Bot bool `json:"bot"`
|
||||
IsPremium bool `json:"isPremium"`
|
||||
Hash string `json:"hash"`
|
||||
TgSession string `json:"tgSession,omitempty"`
|
||||
}
|
||||
|
||||
type SessionData struct {
|
||||
|
|
Loading…
Reference in a new issue