mirror of
https://github.com/tgdrive/teldrive.git
synced 2025-01-31 03:19:19 +08:00
119 lines
3 KiB
Go
119 lines
3 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/ogen-go/ogen/ogenerrors"
|
|
"github.com/tgdrive/teldrive/internal/api"
|
|
"github.com/tgdrive/teldrive/internal/cache"
|
|
"github.com/tgdrive/teldrive/internal/config"
|
|
"github.com/tgdrive/teldrive/pkg/models"
|
|
"github.com/tgdrive/teldrive/pkg/types"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// authContextKey is a dedicated string type for context keys set by this
// package, so they cannot collide with plain string keys used elsewhere.
type authContextKey string

const (
	// authKey is the context key under which handleAuth stores the verified
	// *types.JWTClaims and from which GetUser / GetJWTUser read them back.
	authKey authContextKey = "authUser"
)
|
|
|
|
func Encode(secret string, claims *types.JWTClaims) (string, error) {
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
return token.SignedString([]byte(secret))
|
|
}
|
|
|
|
func Decode(secret string, token string) (*types.JWTClaims, error) {
|
|
claims := &types.JWTClaims{}
|
|
|
|
tkn, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(secret), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !tkn.Valid {
|
|
return nil, fmt.Errorf("invalid token")
|
|
}
|
|
return claims, err
|
|
|
|
}
|
|
|
|
func GetUser(c context.Context) (int64, string) {
|
|
authUser, _ := c.Value(authKey).(*types.JWTClaims)
|
|
userId, _ := strconv.ParseInt(authUser.Subject, 10, 64)
|
|
return userId, authUser.TgSession
|
|
}
|
|
|
|
func GetJWTUser(c context.Context) *types.JWTClaims {
|
|
authUser, _ := c.Value(authKey).(*types.JWTClaims)
|
|
return authUser
|
|
}
|
|
|
|
func VerifyUser(db *gorm.DB, cache cache.Cacher, secret, authCookie string) (*types.JWTClaims, error) {
|
|
claims, err := Decode(secret, authCookie)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var session *models.Session
|
|
|
|
session, err = GetSessionByHash(db, cache, claims.Hash)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid session")
|
|
}
|
|
|
|
claims.TgSession = session.Session
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
func GetSessionByHash(db *gorm.DB, cache cache.Cacher, hash string) (*models.Session, error) {
|
|
var session models.Session
|
|
key := fmt.Sprintf("sessions:%s", hash)
|
|
|
|
err := cache.Get(key, &session)
|
|
|
|
if err != nil {
|
|
if err := db.Model(&models.Session{}).Where("hash = ?", hash).First(&session).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
cache.Set(key, &session, 0)
|
|
}
|
|
|
|
return &session, nil
|
|
|
|
}
|
|
|
|
type securityHandler struct {
|
|
db *gorm.DB
|
|
cache cache.Cacher
|
|
cfg *config.Config
|
|
}
|
|
|
|
func (s *securityHandler) HandleApiKeyAuth(ctx context.Context, operationName api.OperationName, t api.ApiKeyAuth) (context.Context, error) {
|
|
return s.handleAuth(ctx, t.APIKey)
|
|
}
|
|
|
|
func (s *securityHandler) HandleBearerAuth(ctx context.Context, operationName api.OperationName, t api.BearerAuth) (context.Context, error) {
|
|
return s.handleAuth(ctx, t.Token)
|
|
}
|
|
|
|
func (s *securityHandler) handleAuth(ctx context.Context, token string) (context.Context, error) {
|
|
claims, err := VerifyUser(s.db, s.cache, s.cfg.JWT.Secret, token)
|
|
if err != nil {
|
|
return nil, &ogenerrors.SecurityError{Err: err}
|
|
}
|
|
return context.WithValue(ctx, authKey, claims), nil
|
|
}
|
|
|
|
func NewSecurityHandler(db *gorm.DB, cache cache.Cacher, cfg *config.Config) api.SecurityHandler {
|
|
return &securityHandler{db: db, cache: cache, cfg: cfg}
|
|
}
|
|
|
|
// Compile-time check that *securityHandler implements api.SecurityHandler.
var _ api.SecurityHandler = (*securityHandler)(nil)
|