teldrive/internal/auth/auth.go

119 lines
3 KiB
Go

package auth
import (
"context"
"fmt"
"strconv"
"github.com/golang-jwt/jwt/v5"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/cache"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/pkg/models"
"github.com/tgdrive/teldrive/pkg/types"
"gorm.io/gorm"
)
// authContextKey is a package-private named string type used for context
// keys set by this package. Because it is a distinct type, its keys cannot
// collide with plain string keys stored in the same context by other code.
type authContextKey string

// authKey is the context key under which handleAuth stores the verified
// *types.JWTClaims, and from which GetUser and GetJWTUser read them back.
const authKey = authContextKey("authUser")
// Encode signs claims with secret using HMAC-SHA256 (jwt.SigningMethodHS256)
// and returns the resulting signed token string. The error is whatever
// jwt's SignedString reports while signing.
func Encode(secret string, claims *types.JWTClaims) (string, error) {
	signingKey := []byte(secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(signingKey)
}
// Decode parses token, verifies its signature with secret, and returns the
// embedded claims.
//
// Only HS256 (the algorithm Encode signs with) is accepted. The parser is
// told so explicitly via jwt.WithValidMethods, so a token whose header names
// any other algorithm is rejected instead of being trusted. Without this
// restriction the keyfunc would hand the secret to whatever method the
// token itself declares.
//
// A non-nil error is returned when the token is malformed, uses a
// disallowed algorithm, fails signature verification, or fails the claim
// validation performed by jwt/v5. It is also returned if the parser
// otherwise reports the token as not valid.
func Decode(secret string, token string) (*types.JWTClaims, error) {
	claims := &types.JWTClaims{}
	tkn, err := jwt.ParseWithClaims(
		token,
		claims,
		func(*jwt.Token) (interface{}, error) {
			return []byte(secret), nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil {
		return nil, err
	}
	// Defensive: jwt/v5 reports invalid tokens through err, but keep the
	// explicit check so a future library change cannot slip one through.
	if !tkn.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	return claims, nil
}
// GetUser returns the numeric user ID (parsed from the JWT Subject) and the
// Telegram session string of the authenticated caller stored in c under
// authKey.
//
// If c carries no *types.JWTClaims (for example, it did not pass through
// the security handler), GetUser returns 0 and "" instead of dereferencing
// a nil pointer. A Subject that fails to parse as a base-10 int64 yields
// whatever strconv.ParseInt returns on error (0 for malformed input). That
// error is ignored, as before.
func GetUser(c context.Context) (int64, string) {
	authUser, ok := c.Value(authKey).(*types.JWTClaims)
	if !ok || authUser == nil {
		return 0, ""
	}
	userId, _ := strconv.ParseInt(authUser.Subject, 10, 64)
	return userId, authUser.TgSession
}
// GetJWTUser returns the claims stored in c under authKey. It returns nil
// when c holds no value of type *types.JWTClaims for that key.
func GetJWTUser(c context.Context) *types.JWTClaims {
	if claims, ok := c.Value(authKey).(*types.JWTClaims); ok {
		return claims
	}
	return nil
}
// VerifyUser decodes the JWT in authCookie with secret, then looks up the
// session referenced by the token's Hash claim (cache first, then db). On
// success it copies the stored session string into claims.TgSession and
// returns the claims.
//
// Errors from Decode are returned unchanged. Any session-lookup failure is
// replaced by the generic error "invalid session", and the underlying
// lookup error is not propagated.
func VerifyUser(db *gorm.DB, cache cache.Cacher, secret, authCookie string) (*types.JWTClaims, error) {
	claims, decodeErr := Decode(secret, authCookie)
	if decodeErr != nil {
		return nil, decodeErr
	}

	session, lookupErr := GetSessionByHash(db, cache, claims.Hash)
	if lookupErr != nil {
		return nil, fmt.Errorf("invalid session")
	}

	claims.TgSession = session.Session
	return claims, nil
}
// GetSessionByHash returns the session whose hash column equals hash.
//
// It first tries the cache under the key "sessions:<hash>". On a cache miss
// (any error from cache.Get) it queries db for the first matching row. It
// then stores that row in the cache with a TTL argument of 0 and returns it.
// The result of cache.Set is not checked. Database errors are returned
// unchanged.
func GetSessionByHash(db *gorm.DB, cache cache.Cacher, hash string) (*models.Session, error) {
	session := &models.Session{}
	cacheKey := "sessions:" + hash

	if cache.Get(cacheKey, session) == nil {
		return session, nil
	}

	if dbErr := db.Model(&models.Session{}).Where("hash = ?", hash).First(session).Error; dbErr != nil {
		return nil, dbErr
	}
	cache.Set(cacheKey, session, 0)
	return session, nil
}
// securityHandler implements api.SecurityHandler. Each incoming credential
// (API key or bearer token) is verified through VerifyUser using the
// dependencies held here.
type securityHandler struct {
// db is passed to VerifyUser for the session lookup.
db *gorm.DB
// cache is passed to VerifyUser; session lookups consult it before db.
cache cache.Cacher
// cfg supplies the JWT signing secret (cfg.JWT.Secret).
cfg *config.Config
}
// HandleApiKeyAuth authenticates a request that presented an API key. It
// verifies t.APIKey exactly like a bearer token, via handleAuth.
// operationName is not used.
func (s *securityHandler) HandleApiKeyAuth(ctx context.Context, operationName api.OperationName, t api.ApiKeyAuth) (context.Context, error) {
	apiKey := t.APIKey
	return s.handleAuth(ctx, apiKey)
}
// HandleBearerAuth authenticates a request that presented a bearer token.
// It verifies t.Token via handleAuth. operationName is not used.
func (s *securityHandler) HandleBearerAuth(ctx context.Context, operationName api.OperationName, t api.BearerAuth) (context.Context, error) {
	bearer := t.Token
	return s.handleAuth(ctx, bearer)
}
// handleAuth verifies token with VerifyUser using the handler's db, cache and
// JWT secret.
//
// On success it returns a child of ctx carrying the verified claims under
// authKey. On failure it returns a nil context and wraps the verification
// error in *ogenerrors.SecurityError.
func (s *securityHandler) handleAuth(ctx context.Context, token string) (context.Context, error) {
	verified, verifyErr := VerifyUser(s.db, s.cache, s.cfg.JWT.Secret, token)
	if verifyErr == nil {
		return context.WithValue(ctx, authKey, verified), nil
	}
	return nil, &ogenerrors.SecurityError{Err: verifyErr}
}
func NewSecurityHandler(db *gorm.DB, cache cache.Cacher, cfg *config.Config) api.SecurityHandler {
return &securityHandler{db: db, cache: cache, cfg: cfg}
}
var _ api.SecurityHandler = (*securityHandler)(nil)