Mirror of https://github.com/knadh/listmonk.git
Synced 2025-03-01 16:55:26 +08:00
372 lines · 10 KiB · Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/subtle"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/knadh/listmonk/models"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/echo/v4/middleware"
|
|
"github.com/zerodha/simplesessions/stores/postgres/v3"
|
|
"github.com/zerodha/simplesessions/v3"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// UserKey is the key on which the User profile is set on echo handlers.
const UserKey = "auth_user"

// SessionKey is the echo context key on which the auth middleware stores the
// *simplesessions.Session of a cookie-based login.
const SessionKey = "auth_session"

// SuperAdminRole is the role ID for which permission checks are skipped.
const SuperAdminRole = 1
|
|
|
|
// Identifiers for the two kinds of login session.
// NOTE(review): neither constant is referenced elsewhere in this file.
const (
	sessTypeNative = "native" // username/password login
	sessTypeOIDC   = "oidc"   // OpenID Connect login
)
|
|
|
|
type OIDCConfig struct {
|
|
Enabled bool `json:"enabled"`
|
|
ProviderURL string `json:"provider_url"`
|
|
RedirectURL string `json:"redirect_url"`
|
|
ClientID string `json:"client_id"`
|
|
ClientSecret string `json:"client_secret"`
|
|
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper middleware.Skipper
|
|
}
|
|
|
|
type BasicAuthConfig struct {
|
|
Enabled bool `json:"enabled"`
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
type Config struct {
|
|
OIDC OIDCConfig
|
|
BasicAuth BasicAuthConfig
|
|
}
|
|
|
|
// Callbacks takes two callback functions required by simplesessions.
|
|
type Callbacks struct {
|
|
SetCookie func(cookie *http.Cookie, w interface{}) error
|
|
GetCookie func(name string, r interface{}) (*http.Cookie, error)
|
|
GetUser func(id int) (models.User, error)
|
|
}
|
|
|
|
type Auth struct {
|
|
tokens map[string]models.User
|
|
sync.RWMutex
|
|
|
|
cfg Config
|
|
oauthCfg oauth2.Config
|
|
verifier *oidc.IDTokenVerifier
|
|
skipper middleware.Skipper
|
|
sess *simplesessions.Manager
|
|
sessStore *postgres.Store
|
|
cb *Callbacks
|
|
log *log.Logger
|
|
}
|
|
|
|
func New(cfg Config, db *sql.DB, cb *Callbacks, lo *log.Logger) (*Auth, error) {
|
|
a := &Auth{
|
|
cfg: cfg,
|
|
cb: cb,
|
|
log: lo,
|
|
}
|
|
|
|
// Initialize OIDC.
|
|
if cfg.OIDC.Enabled {
|
|
provider, err := oidc.NewProvider(context.Background(), cfg.OIDC.ProviderURL)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
a.verifier = provider.Verifier(&oidc.Config{
|
|
ClientID: cfg.OIDC.ClientID,
|
|
})
|
|
|
|
a.oauthCfg = oauth2.Config{
|
|
ClientID: cfg.OIDC.ClientID,
|
|
ClientSecret: cfg.OIDC.ClientSecret,
|
|
Endpoint: provider.Endpoint(),
|
|
RedirectURL: cfg.OIDC.RedirectURL,
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
|
|
a.skipper = cfg.OIDC.Skipper
|
|
}
|
|
|
|
// Initialize session manager.
|
|
a.sess = simplesessions.New(simplesessions.Options{
|
|
EnableAutoCreate: false,
|
|
SessionIDLength: 64,
|
|
Cookie: simplesessions.CookieOptions{
|
|
IsHTTPOnly: true,
|
|
MaxAge: time.Hour * 24 * 7,
|
|
},
|
|
})
|
|
st, err := postgres.New(postgres.Opt{}, db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
a.sessStore = st
|
|
a.sess.UseStore(st)
|
|
a.sess.SetCookieHooks(cb.GetCookie, cb.SetCookie)
|
|
|
|
// Prune dead sessions from the DB periodically.
|
|
go func() {
|
|
if err := st.Prune(); err != nil {
|
|
lo.Printf("error pruning login sessions: %v", err)
|
|
}
|
|
time.Sleep(time.Hour * 12)
|
|
}()
|
|
|
|
return a, nil
|
|
}
|
|
|
|
// SetTokens caches tokens for authenticating API client calls.
|
|
func (o *Auth) SetTokens(tokens map[string]models.User) {
|
|
o.Lock()
|
|
defer o.Unlock()
|
|
|
|
o.tokens = make(map[string]models.User, len(tokens))
|
|
for userID, u := range tokens {
|
|
o.tokens[userID] = u
|
|
}
|
|
}
|
|
|
|
// GetToken validates an API user+token.
|
|
func (o *Auth) GetToken(user string, token string) (models.User, bool) {
|
|
o.RLock()
|
|
t, ok := o.tokens[user]
|
|
o.RUnlock()
|
|
|
|
if !ok || subtle.ConstantTimeCompare([]byte(t.Password.String), []byte(token)) != 1 {
|
|
return models.User{}, false
|
|
}
|
|
|
|
return t, true
|
|
}
|
|
|
|
// GetOIDCAuthURL returns the OIDC provider's auth URL to redirect to.
|
|
func (o *Auth) GetOIDCAuthURL(state, nonce string) string {
|
|
return o.oauthCfg.AuthCodeURL(state, oidc.Nonce(nonce))
|
|
}
|
|
|
|
// ExchangeOIDCToken takes an OIDC authorization code (recieved via redirect from the OIDC provider),
|
|
// validates it, and returns an OIDC token for subsequent auth.
|
|
func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, models.User, error) {
|
|
var user models.User
|
|
|
|
tk, err := o.oauthCfg.Exchange(context.TODO(), code)
|
|
if err != nil {
|
|
return "", user, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
|
|
}
|
|
|
|
rawIDTk, ok := tk.Extra("id_token").(string)
|
|
if !ok {
|
|
return "", user, echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
|
|
}
|
|
|
|
idTk, err := o.verifier.Verify(context.TODO(), rawIDTk)
|
|
if err != nil {
|
|
return "", user, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
|
|
}
|
|
|
|
if idTk.Nonce != nonce {
|
|
return "", user, echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match")
|
|
}
|
|
|
|
if err := idTk.Claims(&user); err != nil {
|
|
return "", user, errors.New("error getting user from OIDC")
|
|
}
|
|
|
|
return rawIDTk, user, nil
|
|
}
|
|
|
|
// Middleware is the HTTP middleware used for wrapping HTTP handlers registered on the echo router.
|
|
// It authorizes token (BasicAuth/token) based and cookie based sessions and on successful auth,
|
|
// sets the authenticated User{} on the echo context on the key UserKey. On failure, it sets an Error{}
|
|
// instead on the same key.
|
|
func (o *Auth) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
// It's an `Authorization` header request.
|
|
hdr := c.Response().Header().Get("Authorization")
|
|
if len(hdr) > 0 {
|
|
key, token, err := parseAuthHeader(hdr)
|
|
if err != nil {
|
|
c.Set(UserKey, echo.NewHTTPError(http.StatusForbidden, err.Error()))
|
|
return next(c)
|
|
}
|
|
|
|
// Validate the token.
|
|
user, ok := o.GetToken(key, token)
|
|
if !ok {
|
|
c.Set(UserKey, echo.NewHTTPError(http.StatusForbidden, "invalid token:secret"))
|
|
return next(c)
|
|
}
|
|
|
|
// Set the user details on the handler context.
|
|
c.Set(UserKey, user)
|
|
return next(c)
|
|
}
|
|
|
|
// It's a cookie based session.
|
|
sess, user, err := o.validateSession(c)
|
|
if err != nil {
|
|
c.Set(UserKey, echo.NewHTTPError(http.StatusForbidden, "invalid session"))
|
|
return next(c)
|
|
}
|
|
|
|
// Set the user details on the handler context.
|
|
c.Set(UserKey, user)
|
|
c.Set(SessionKey, sess)
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
func (o *Auth) Perm(next echo.HandlerFunc, perm string) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
u, ok := c.Get(UserKey).(models.User)
|
|
if !ok {
|
|
c.Set(UserKey, echo.NewHTTPError(http.StatusForbidden, "invalid session"))
|
|
return next(c)
|
|
}
|
|
|
|
// If there's no permission set on the handler or if the current user is a super admin, do no checks.
|
|
if perm == "" || u.RoleID == SuperAdminRole {
|
|
return next(c)
|
|
}
|
|
|
|
// Check if the current handler's permission is in the user's permission map.
|
|
if _, ok := u.PermissionsMap[perm]; !ok {
|
|
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("permission denied (%s)", perm))
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// SetSession creates and sets a session (post successful login/auth).
|
|
func (o *Auth) SetSession(u models.User, oidcToken string, c echo.Context) error {
|
|
// sess, err := o.sess.Acquire(nil, c, c)
|
|
// if err != nil {
|
|
// return echo.NewHTTPError(http.StatusInternalServerError, "error creating session")
|
|
// }
|
|
|
|
sess, err := o.sess.NewSession(c, c)
|
|
if err != nil {
|
|
o.log.Printf("error creating login session: %v", err)
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "error creating session")
|
|
}
|
|
|
|
if err := sess.SetMulti(map[string]interface{}{"user_id": u.ID, "oidc_token": oidcToken}); err != nil {
|
|
o.log.Printf("error setting login session: %v", err)
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "error creating session")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (o *Auth) validateSession(c echo.Context) (*simplesessions.Session, models.User, error) {
|
|
// Cookie session.
|
|
sess, err := o.sess.Acquire(nil, c, c)
|
|
if err != nil {
|
|
return nil, models.User{}, echo.NewHTTPError(http.StatusForbidden, err.Error())
|
|
}
|
|
|
|
// Get the session variables.
|
|
vars, err := sess.GetMulti("user_id", "oidc_token")
|
|
if err != nil {
|
|
return nil, models.User{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
|
|
// Validate the user ID in the session.
|
|
userID, err := o.sessStore.Int(vars["user_id"], nil)
|
|
if err != nil || userID < 1 {
|
|
o.log.Printf("error fetching session user ID: %v", err)
|
|
return nil, models.User{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
|
|
// If it's an OIDC session, validate the claim.
|
|
if vars["oidc_token"] != "" {
|
|
if !o.cfg.OIDC.Enabled {
|
|
return nil, models.User{}, echo.NewHTTPError(http.StatusForbidden, "OIDC aut his not enabled.")
|
|
}
|
|
if _, err := o.verifyOIDC(vars["oidc_token"].(string), c); err != nil {
|
|
return nil, models.User{}, err
|
|
}
|
|
}
|
|
|
|
// Fetch user details from the database.
|
|
user, err := o.cb.GetUser(userID)
|
|
if err != nil {
|
|
o.log.Printf("error fetching session user: %v", err)
|
|
}
|
|
|
|
return sess, user, err
|
|
}
|
|
|
|
func (o *Auth) verifyOIDC(token string, c echo.Context) (models.User, error) {
|
|
idTk, err := o.verifier.Verify(c.Request().Context(), token)
|
|
if err != nil {
|
|
return models.User{}, err
|
|
}
|
|
|
|
var user models.User
|
|
if err := idTk.Claims(&user); err != nil {
|
|
return user, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("error verifying OIDC claim: %v", user))
|
|
}
|
|
|
|
if user.ID < 1 {
|
|
return user, echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("invalid user ID in OIDC: %v", user))
|
|
}
|
|
|
|
return user, err
|
|
}
|
|
|
|
// parseAuthHeader parses the Authorization header and returns the api_key and access_token.
|
|
func parseAuthHeader(h string) (string, string, error) {
|
|
const authBasic = "Basic"
|
|
const authToken = "token"
|
|
|
|
var (
|
|
pair []string
|
|
delim = ":"
|
|
)
|
|
|
|
if strings.HasPrefix(h, authToken) {
|
|
// token api_key:access_token.
|
|
pair = strings.SplitN(strings.Trim(h[len(authToken):], " "), delim, 2)
|
|
} else if strings.HasPrefix(h, authBasic) {
|
|
// HTTP BasicAuth. This is supported for backwards compatibility.
|
|
payload, err := base64.StdEncoding.DecodeString(string(strings.Trim(h[len(authBasic):], " ")))
|
|
if err != nil {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "invalid Base64 value in Basic Authorization header")
|
|
}
|
|
pair = strings.SplitN(string(payload), delim, 2)
|
|
} else {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "unknown Authorization scheme")
|
|
}
|
|
|
|
if len(pair) < 2 {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "api_key:token missing")
|
|
}
|
|
|
|
if len(pair[0]) == 0 || len(pair[1]) == 0 {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "empty `api_key` or `token`")
|
|
}
|
|
|
|
return pair[0], pair[1], nil
|
|
}
|