refactor: Switch to golang-jwt for auth

This commit is contained in:
divyam234 2024-07-10 20:24:15 +05:30
parent e7865a7637
commit 20d39cd5b9
13 changed files with 149 additions and 194 deletions

View file

@ -1,15 +1,17 @@
package api
import (
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/middleware"
"github.com/divyam234/teldrive/pkg/controller"
"github.com/divyam234/teldrive/ui"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config) *gin.Engine {
authmiddleware := middleware.Authmiddleware(cnf.JWT.Secret)
func InitRouter(r *gin.Engine, c *controller.Controller, cnf *config.Config, db *gorm.DB, cache *cache.Cache) *gin.Engine {
authmiddleware := middleware.Authmiddleware(cnf.JWT.Secret, db, cache)
api := r.Group("/api")
{
auth := api.Group("/auth")

View file

@ -12,6 +12,7 @@ import (
"unicode"
"github.com/divyam234/teldrive/api"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/internal/config"
"github.com/divyam234/teldrive/internal/database"
"github.com/divyam234/teldrive/internal/duration"
@ -34,6 +35,7 @@ import (
"github.com/spf13/viper"
"go.uber.org/fx"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
)
func NewRun() *cobra.Command {
@ -134,12 +136,9 @@ func runApplication(conf *config.Config) {
fx.Supply(logging.DefaultLogger().Desugar()),
fx.NopLogger,
fx.StopTimeout(conf.Server.GracefulShutdown+time.Second),
fx.Invoke(
initApp,
cron.StartCronJobs,
),
fx.Provide(
database.NewDatabase,
cache.DefaultCache,
kv.NewBoltKV,
tgc.NewStreamWorker(tgContext),
tgc.NewUploadWorker,
@ -149,6 +148,10 @@ func runApplication(conf *config.Config) {
services.NewUserService,
controller.NewController,
),
fx.Invoke(
initApp,
cron.StartCronJobs,
),
)
app.Run()
@ -223,7 +226,7 @@ func modifyFlag(s string) string {
return string(result)
}
func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller) *gin.Engine {
func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller, db *gorm.DB, cache *cache.Cache) *gin.Engine {
gin.SetMode(gin.ReleaseMode)
@ -258,7 +261,7 @@ func initApp(lc fx.Lifecycle, cfg *config.Config, c *controller.Controller) *gin
c.Next()
})
r = api.InitRouter(r, c, cfg)
r = api.InitRouter(r, c, cfg, db, cache)
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: r,

2
go.mod
View file

@ -10,7 +10,7 @@ require (
github.com/gin-contrib/zap v1.1.3
github.com/gin-gonic/gin v1.10.0
github.com/go-co-op/gocron v1.37.0
github.com/go-jose/go-jose/v3 v3.0.3
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/gotd/contrib v0.20.0
github.com/gotd/td v0.105.0
github.com/iyear/connectproxy v0.1.1

19
go.sum
View file

@ -58,8 +58,6 @@ github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb
github.com/go-faster/xor v0.3.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38=
github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ=
github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7G7k=
github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@ -73,6 +71,8 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
@ -82,7 +82,6 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -245,24 +244,19 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -272,22 +266,14 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
@ -295,7 +281,6 @@ golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=

96
internal/auth/auth.go Normal file
View file

@ -0,0 +1,96 @@
package auth
import (
"fmt"
"strconv"
"strings"
"github.com/divyam234/teldrive/internal/cache"
"github.com/divyam234/teldrive/pkg/models"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
)
func Encode(secret string, claims *types.JWTClaims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(secret))
}
func Decode(secret string, token string) (*types.JWTClaims, error) {
claims := &types.JWTClaims{}
tkn, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil {
return nil, err
}
if !tkn.Valid {
return nil, fmt.Errorf("invalid token")
}
return claims, err
}
func GetUser(c *gin.Context) (int64, string) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}
func VerifyUser(c *gin.Context, db *gorm.DB, cache *cache.Cache, secret string) (*types.JWTClaims, error) {
var token string
cookie, err := c.Request.Cookie("user-session")
if err != nil {
authHeader := c.GetHeader("Authorization")
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
return nil, fmt.Errorf("missing auth token")
}
token = bearerToken[1]
} else {
token = cookie.Value
}
claims, err := Decode(secret, token)
if err != nil {
return nil, err
}
var session *models.Session
session, err = GetSessionByHash(db, cache, claims.Hash)
if err != nil {
return nil, fmt.Errorf("invalid session")
}
claims.TgSession = session.Session
return claims, nil
}
func GetSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) {
var session models.Session
key := fmt.Sprintf("sessions:%s", hash)
err := cache.Get(key, &session)
if err != nil {
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
return nil, err
}
cache.Set(key, &session, 0)
}
return &session, nil
}

View file

@ -1,104 +0,0 @@
package auth
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
)
func Encode(secret string, payload *types.JWTClaims) (string, error) {
rcpt := jose.Recipient{
Algorithm: jose.PBES2_HS256_A128KW,
Key: secret,
}
enc, err := jose.NewEncrypter(jose.A128CBC_HS256, rcpt, nil)
if err != nil {
return "", err
}
jwt, _ := json.Marshal(payload)
jweObject, err := enc.Encrypt(jwt)
if err != nil {
return "", err
}
jweToken, err := jweObject.CompactSerialize()
if err != nil {
return "", err
}
return jweToken, nil
}
func Decode(secret string, token string) (*types.JWTClaims, error) {
jwe, err := jose.ParseEncrypted(token)
if err != nil {
return nil, err
}
decryptedData, err := jwe.Decrypt(secret)
if err != nil {
return nil, err
}
jwtToken := &types.JWTClaims{}
err = json.Unmarshal(decryptedData, jwtToken)
if err != nil {
return nil, err
}
return jwtToken, nil
}
func GetUser(c *gin.Context) (int64, string) {
val, _ := c.Get("jwtUser")
jwtUser := val.(*types.JWTClaims)
userId, _ := strconv.ParseInt(jwtUser.Subject, 10, 64)
return userId, jwtUser.TgSession
}
func VerifyUser(c *gin.Context, secret string) (*types.JWTClaims, error) {
var token string
cookie, err := c.Request.Cookie("user-session")
if err != nil {
authHeader := c.GetHeader("Authorization")
bearerToken := strings.Split(authHeader, "Bearer ")
if len(bearerToken) != 2 {
return nil, fmt.Errorf("missing auth token")
}
token = bearerToken[1]
} else {
token = cookie.Value
}
now := time.Now().UTC()
jwePayload, err := Decode(secret, token)
if err != nil {
return nil, err
}
if *jwePayload.Expiry < *jwt.NewNumericDate(now) {
return nil, fmt.Errorf("token expired")
}
return jwePayload, nil
}

View file

@ -7,7 +7,9 @@ import (
"github.com/divyam234/cors"
"github.com/divyam234/teldrive/internal/auth"
"github.com/divyam234/teldrive/internal/cache"
"github.com/gin-contrib/secure"
"gorm.io/gorm"
"github.com/gin-gonic/gin"
)
@ -38,9 +40,9 @@ func Cors() gin.HandlerFunc {
})
}
func Authmiddleware(secret string) gin.HandlerFunc {
func Authmiddleware(secret string, db *gorm.DB, cache *cache.Cache) gin.HandlerFunc {
return func(c *gin.Context) {
user, err := auth.VerifyUser(c, secret)
user, err := auth.VerifyUser(c, db, cache, secret)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return

View file

@ -20,6 +20,7 @@ var internalErrors = []string{
"memory limit exit",
"connection dead",
"engine was closed",
"STORAGE_CHOOSE_VOLUME_FAILED",
}
type retry struct {

View file

@ -25,7 +25,7 @@ import (
"github.com/divyam234/teldrive/pkg/schemas"
"github.com/divyam234/teldrive/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/websocket"
"github.com/gotd/td/session"
tgauth "github.com/gotd/td/telegram/auth"
@ -36,12 +36,13 @@ import (
)
type AuthService struct {
db *gorm.DB
cnf *config.Config
db *gorm.DB
cnf *config.Config
cache *cache.Cache
}
func NewAuthService(db *gorm.DB, cnf *config.Config) *AuthService {
return &AuthService{db: db, cnf: cnf}
func NewAuthService(db *gorm.DB, cnf *config.Config, cache *cache.Cache) *AuthService {
return &AuthService{db: db, cnf: cnf, cache: cache}
}
@ -54,16 +55,16 @@ func (as *AuthService) LogIn(c *gin.Context, session *schemas.TgSession) (*schem
now := time.Now().UTC()
jwtClaims := &types.JWTClaims{Claims: jwt.Claims{
Subject: strconv.FormatInt(session.UserID, 10),
IssuedAt: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(as.cnf.JWT.SessionTime)),
}, TgSession: session.Sesssion,
jwtClaims := &types.JWTClaims{
Name: session.Name,
UserName: session.UserName,
Bot: session.Bot,
IsPremium: session.IsPremium,
}
RegisteredClaims: jwt.RegisteredClaims{
Subject: strconv.FormatInt(session.UserID, 10),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(as.cnf.JWT.SessionTime)),
}}
tokenhash := md5.Sum([]byte(session.Sesssion))
hexToken := hex.EncodeToString(tokenhash[:])
@ -144,43 +145,31 @@ func (as *AuthService) LogIn(c *gin.Context, session *schemas.TgSession) (*schem
func (as *AuthService) GetSession(c *gin.Context) *schemas.Session {
cookie, err := c.Request.Cookie("user-session")
claims, err := auth.VerifyUser(c, as.db, as.cache, as.cnf.JWT.Secret)
if err != nil {
return nil
}
jwePayload, err := auth.Decode(as.cnf.JWT.Secret, cookie.Value)
if err != nil {
return nil
}
cache := cache.FromContext(c)
_, err = getSessionByHash(as.db, cache, jwePayload.Hash)
if err != nil {
return nil
}
claims.TgSession = ""
now := time.Now().UTC()
newExpires := now.Add(as.cnf.JWT.SessionTime)
userId, _ := strconv.ParseInt(jwePayload.Subject, 10, 64)
userId, _ := strconv.ParseInt(claims.Subject, 10, 64)
session := &schemas.Session{Name: jwePayload.Name,
UserName: jwePayload.UserName,
session := &schemas.Session{Name: claims.Name,
UserName: claims.UserName,
UserId: userId,
Hash: jwePayload.Hash,
Hash: claims.Hash,
Expires: newExpires.Format(time.RFC3339)}
jwePayload.IssuedAt = jwt.NewNumericDate(now)
claims.IssuedAt = jwt.NewNumericDate(now)
jwePayload.Expiry = jwt.NewNumericDate(newExpires)
claims.ExpiresAt = jwt.NewNumericDate(newExpires)
jweToken, err := auth.Encode(as.cnf.JWT.Secret, jwePayload)
jweToken, err := auth.Encode(as.cnf.JWT.Secret, claims)
if err != nil {
return nil

View file

@ -104,21 +104,3 @@ func getBotsToken(ctx context.Context, db *gorm.DB, userID, channelId int64) ([]
return bots, nil
}
func getSessionByHash(db *gorm.DB, cache *cache.Cache, hash string) (*models.Session, error) {
var session models.Session
key := fmt.Sprintf("sessions:%s", hash)
err := cache.Get(key, &session)
if err != nil {
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
return nil, err
}
cache.Set(key, &session, 0)
}
return &session, nil
}

View file

@ -78,10 +78,11 @@ type FileService struct {
db *gorm.DB
cnf *config.Config
worker *tgc.StreamWorker
cache *cache.Cache
}
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker) *FileService {
return &FileService{db: db, cnf: cnf, worker: worker}
func NewFileService(db *gorm.DB, cnf *config.Config, worker *tgc.StreamWorker, cache *cache.Cache) *FileService {
return &FileService{db: db, cnf: cnf, worker: worker, cache: cache}
}
func (fs *FileService) CreateFile(c *gin.Context, userId int64, fileIn *schemas.FileIn) (*schemas.FileOut, *types.AppError) {
@ -566,8 +567,6 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
authHash := c.Query("hash")
cache := cache.FromContext(c)
var (
session *models.Session
err error
@ -576,7 +575,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
)
if authHash == "" {
user, err = auth.VerifyUser(c, fs.cnf.JWT.Secret)
user, err = auth.VerifyUser(c, fs.db, fs.cache, fs.cnf.JWT.Secret)
if err != nil {
http.Error(w, "missing session or authash", http.StatusUnauthorized)
return
@ -584,7 +583,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
userId, _ := strconv.ParseInt(user.Subject, 10, 64)
session = &models.Session{UserId: userId, Session: user.TgSession}
} else {
session, err = getSessionByHash(fs.db, cache, authHash)
session, err = auth.GetSessionByHash(fs.db, fs.cache, authHash)
if err != nil {
http.Error(w, "invalid hash", http.StatusBadRequest)
return
@ -595,7 +594,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
key := fmt.Sprintf("files:%s", fileID)
err = cache.Get(key, file)
err = fs.cache.Get(key, file)
if err != nil {
file, appErr = fs.GetFileByID(fileID)
@ -603,7 +602,7 @@ func (fs *FileService) GetFileStream(c *gin.Context, download bool) {
http.Error(w, appErr.Error.Error(), http.StatusBadRequest)
return
}
cache.Set(key, file, 0)
fs.cache.Set(key, file, 0)
}
c.Header("Accept-Ranges", "bytes")

View file

@ -21,7 +21,7 @@ type FileServiceSuite struct {
func (s *FileServiceSuite) SetupSuite() {
s.db = database.NewTestDatabase(s.T(), false)
s.srv = NewFileService(s.db, nil, nil)
s.srv = NewFileService(s.db, nil, nil, nil)
}
func (s *FileServiceSuite) SetupTest() {

View file

@ -1,7 +1,7 @@
package types
import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/golang-jwt/jwt/v5"
"github.com/gotd/td/session"
)
@ -18,13 +18,13 @@ type Part struct {
}
type JWTClaims struct {
jwt.Claims
TgSession string `json:"tgSession"`
jwt.RegisteredClaims
Name string `json:"name"`
UserName string `json:"userName"`
Bot bool `json:"bot"`
IsPremium bool `json:"isPremium"`
Hash string `json:"hash"`
TgSession string `json:"tgSession,omitempty"`
}
type SessionData struct {