mirror of
https://github.com/knadh/listmonk.git
synced 2025-10-08 22:37:22 +08:00
455 lines
13 KiB
Go
455 lines
13 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/subtle"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/zerodha/simplesessions/stores/postgres/v3"
|
|
"github.com/zerodha/simplesessions/v3"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// OIDCclaim holds the user claims that ExchangeOIDCToken decodes from a
// verified OIDC ID token and, when the ID token carries no e-mail, from the
// provider's userinfo response. Each field maps to the claim name given in
// its JSON tag.
type OIDCclaim struct {
	// Email is the "email" claim.
	Email string `json:"email"`

	// EmailVerified is the "email_verified" claim. ExchangeOIDCToken does not
	// check it; callers decide whether an unverified address is acceptable.
	EmailVerified bool `json:"email_verified"`

	// Sub is the "sub" (subject) claim.
	Sub string `json:"sub"`

	// Picture is the "picture" claim.
	Picture string `json:"picture"`

	// Name is the "name" claim.
	Name string `json:"name"`

	// PreferredUsername is the "preferred_username" claim.
	PreferredUsername string `json:"preferred_username"`
}
|
|
|
|
// OIDCConfig configures OpenID Connect login. It is read lazily by initOIDC
// the first time an OIDC provider, verifier or OAuth config is needed.
type OIDCConfig struct {
	// Enabled must be true; otherwise initOIDC returns an error.
	Enabled bool `json:"enabled"`

	// ProviderURL is passed to oidc.NewProvider for provider discovery.
	ProviderURL string `json:"provider_url"`

	// RedirectURL is the OAuth2 redirect (callback) URL.
	RedirectURL string `json:"redirect_url"`

	// ClientID is used both in the OAuth2 config and in the oidc.Config
	// handed to the ID-token verifier.
	ClientID string `json:"client_id"`

	// ClientSecret is the OAuth2 client secret used for the code exchange.
	ClientSecret string `json:"client_secret"`

	// AutoCreateUsers, DefaultUserRoleID and DefaultListRoleID are not read
	// in this file. NOTE(review): presumably they control creating unknown
	// OIDC users and the roles they receive; confirm against callers.
	AutoCreateUsers bool `json:"auto_create_users"`

	DefaultUserRoleID int `json:"default_user_role_id"`

	DefaultListRoleID int `json:"default_list_role_id"`
}
|
|
|
|
// BasicAuthConfig holds static BasicAuth credentials.
//
// NOTE(review): none of these fields is read in this file (parseAuthHeader
// validates Basic credentials against cached API users instead). Presumably
// this is legacy/consumed elsewhere; confirm before relying on it.
type BasicAuthConfig struct {
	Enabled bool `json:"enabled"`

	Username string `json:"username"`

	Password string `json:"password"`
}
|
|
|
|
// Config is the top-level authentication configuration passed to New.
type Config struct {
	// OIDC configures OpenID Connect login.
	OIDC OIDCConfig

	// BasicAuth holds static BasicAuth settings.
	BasicAuth BasicAuthConfig
}
|
|
|
|
// Callbacks holds the functions Auth delegates to its host application.
// SetCookie and GetCookie are the cookie hooks registered with the
// simplesessions manager in New; GetUser is used by validateSession to load
// the user referenced by a session.
type Callbacks struct {
	// SetCookie writes a session cookie. NOTE(review): w is presumably the
	// echo.Context passed to NewSession/Acquire in this file; confirm.
	SetCookie func(cookie *http.Cookie, w any) error

	// GetCookie reads the named cookie. NOTE(review): r is presumably the
	// echo.Context passed to NewSession/Acquire in this file; confirm.
	GetCookie func(name string, r any) (*http.Cookie, error)

	// GetUser loads a user by ID.
	GetUser func(id int) (User, error)
}
|
|
|
|
// Auth authenticates HTTP requests via cached API credentials or cookie
// sessions, and optionally handles OIDC login. Construct it with New.
type Auth struct {
	// apiUsers caches API users keyed by username (see CacheAPIUsers).
	apiUsers map[string]User

	// The embedded RWMutex guards apiUsers and the lazily-initialized OIDC
	// fields (provider, verifier, oauthCfg).
	sync.RWMutex

	cfg Config

	// oauthCfg, verifier and provider are populated by initOIDC on first use.
	oauthCfg oauth2.Config

	verifier *oidc.IDTokenVerifier

	provider *oidc.Provider

	// sess is the cookie session manager; sessStore is its Postgres backend.
	sess *simplesessions.Manager

	sessStore *postgres.Store

	cb *Callbacks

	log *log.Logger
}
|
|
|
|
// sessPruneInterval is how long the background pruner started by New waits
// between purges of dead login sessions from the database.
var sessPruneInterval = 12 * time.Hour
|
|
|
|
// New returns an initialize Auth instance.
|
|
func New(cfg Config, db *sql.DB, cb *Callbacks, lo *log.Logger) (*Auth, error) {
|
|
a := &Auth{
|
|
cfg: cfg,
|
|
cb: cb,
|
|
log: lo,
|
|
|
|
apiUsers: map[string]User{},
|
|
}
|
|
|
|
|
|
// Initialize session manager.
|
|
a.sess = simplesessions.New(simplesessions.Options{
|
|
EnableAutoCreate: false,
|
|
SessionIDLength: 64,
|
|
Cookie: simplesessions.CookieOptions{
|
|
IsHTTPOnly: true,
|
|
MaxAge: time.Hour * 24 * 7,
|
|
},
|
|
})
|
|
st, err := postgres.New(postgres.Opt{}, db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
a.sessStore = st
|
|
a.sess.UseStore(st)
|
|
a.sess.SetCookieHooks(cb.GetCookie, cb.SetCookie)
|
|
|
|
// Prune dead sessions from the DB periodically.
|
|
go func() {
|
|
if err := st.Prune(); err != nil {
|
|
lo.Printf("error pruning login sessions: %v", err)
|
|
}
|
|
time.Sleep(sessPruneInterval)
|
|
}()
|
|
|
|
return a, nil
|
|
}
|
|
|
|
// CacheAPIUsers caches API users for authenticating requests. It wipes
|
|
// the existing cache every time and is meant for syncing all API users
|
|
// in the database in one shot.
|
|
func (o *Auth) CacheAPIUsers(users []User) {
|
|
o.Lock()
|
|
defer o.Unlock()
|
|
|
|
o.apiUsers = map[string]User{}
|
|
for _, u := range users {
|
|
o.apiUsers[u.Username] = u
|
|
}
|
|
}
|
|
|
|
// CacheAPIUser caches an API user for authenticating requests.
|
|
func (o *Auth) CacheAPIUser(u User) {
|
|
o.Lock()
|
|
o.apiUsers[u.Username] = u
|
|
o.Unlock()
|
|
}
|
|
|
|
// GetAPIToken validates an API user+token.
|
|
func (o *Auth) GetAPIToken(user string, token string) (User, bool) {
|
|
o.RLock()
|
|
t, ok := o.apiUsers[user]
|
|
o.RUnlock()
|
|
|
|
if !ok || subtle.ConstantTimeCompare([]byte(t.Password.String), []byte(token)) != 1 {
|
|
return User{}, false
|
|
}
|
|
|
|
return t, true
|
|
}
|
|
|
|
// initOIDC initializes the OIDC provider, verifier, and OAuth config.
|
|
func (o *Auth) initOIDC() error {
|
|
if !o.cfg.OIDC.Enabled {
|
|
return fmt.Errorf("OIDC is not enabled")
|
|
}
|
|
|
|
provider, err := oidc.NewProvider(context.Background(), o.cfg.OIDC.ProviderURL)
|
|
if err != nil {
|
|
return fmt.Errorf("error initializing OIDC OAuth provider: %v", err)
|
|
}
|
|
|
|
o.verifier = provider.Verifier(&oidc.Config{
|
|
ClientID: o.cfg.OIDC.ClientID,
|
|
})
|
|
|
|
o.oauthCfg = oauth2.Config{
|
|
ClientID: o.cfg.OIDC.ClientID,
|
|
ClientSecret: o.cfg.OIDC.ClientSecret,
|
|
Endpoint: provider.Endpoint(),
|
|
RedirectURL: o.cfg.OIDC.RedirectURL,
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
o.provider = provider
|
|
|
|
return nil
|
|
}
|
|
|
|
// getProvider returns the OIDC provider, initializing it if necessary.
|
|
func (o *Auth) getProvider() (*oidc.Provider, error) {
|
|
o.Lock()
|
|
defer o.Unlock()
|
|
|
|
if o.provider == nil {
|
|
if err := o.initOIDC(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return o.provider, nil
|
|
}
|
|
|
|
// getVerifier returns the OIDC verifier, initializing it if necessary.
|
|
func (o *Auth) getVerifier() (*oidc.IDTokenVerifier, error) {
|
|
o.Lock()
|
|
defer o.Unlock()
|
|
|
|
if o.verifier == nil {
|
|
if err := o.initOIDC(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return o.verifier, nil
|
|
}
|
|
|
|
// getOAuthConfig returns the OAuth config, initializing it if necessary.
|
|
func (o *Auth) getOAuthConfig() (*oauth2.Config, error) {
|
|
o.Lock()
|
|
defer o.Unlock()
|
|
|
|
if o.oauthCfg.ClientID == "" {
|
|
if err := o.initOIDC(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &o.oauthCfg, nil
|
|
}
|
|
|
|
// GetOIDCAuthURL returns the OIDC provider's auth URL to redirect to.
|
|
func (o *Auth) GetOIDCAuthURL(state, nonce string) string {
|
|
cfg, err := o.getOAuthConfig()
|
|
if err != nil {
|
|
o.log.Printf("error getting OAuth config: %v", err)
|
|
return ""
|
|
}
|
|
return cfg.AuthCodeURL(state, oidc.Nonce(nonce))
|
|
}
|
|
|
|
// ExchangeOIDCToken takes an OIDC authorization code (recieved via redirect from the OIDC provider),
|
|
// validates it, and returns an OIDC token for subsequent auth.
|
|
func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error) {
|
|
cfg, err := o.getOAuthConfig()
|
|
if err != nil {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting OAuth config: %v", err))
|
|
}
|
|
|
|
tk, err := cfg.Exchange(context.TODO(), code)
|
|
if err != nil {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
|
|
}
|
|
|
|
rawIDTk, ok := tk.Extra("id_token").(string)
|
|
if !ok {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
|
|
}
|
|
|
|
verifier, err := o.getVerifier()
|
|
if err != nil {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting verifier: %v", err))
|
|
}
|
|
|
|
idTk, err := verifier.Verify(context.TODO(), rawIDTk)
|
|
if err != nil {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
|
|
}
|
|
|
|
if idTk.Nonce != nonce {
|
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match")
|
|
}
|
|
|
|
var claims OIDCclaim
|
|
if err := idTk.Claims(&claims); err != nil {
|
|
return "", OIDCclaim{}, errors.New("error getting user from OIDC")
|
|
}
|
|
|
|
// If claims doesn't have the e-mail, attempt to fetch it from the userinfo endpoint.
|
|
if claims.Email == "" {
|
|
provider, err := o.getProvider()
|
|
if err != nil {
|
|
return "", OIDCclaim{}, fmt.Errorf("error getting provider: %v", err)
|
|
}
|
|
|
|
userInfo, err := provider.UserInfo(context.TODO(), oauth2.StaticTokenSource(tk))
|
|
if err != nil {
|
|
return "", OIDCclaim{}, errors.New("error fetching user info from OIDC")
|
|
}
|
|
|
|
// Parse the UserInfo claims into the claims struct
|
|
if err := userInfo.Claims(&claims); err != nil {
|
|
return "", OIDCclaim{}, errors.New("error parsing user info claims")
|
|
}
|
|
}
|
|
|
|
return rawIDTk, claims, nil
|
|
}
|
|
|
|
// Middleware is the HTTP middleware used for wrapping HTTP handlers registered on the echo router.
|
|
// It authorizes token (BasicAuth/token) based and cookie based sessions and on successful auth,
|
|
// sets the authenticated User{} on the echo context on the key UserKey. On failure, it sets an Error{}
|
|
// instead on the same key.
|
|
func (o *Auth) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
// It's an `Authorization` header request.
|
|
hdr := strings.TrimSpace(c.Request().Header.Get("Authorization"))
|
|
|
|
// If cookie is set, ignore BasicAuth. This is to preserve backwards compatibility
|
|
// in v3 -> v4 upgrade where the user browser sessions would still have old
|
|
// BasicAuth credentials, which no longer work in the new system which expects
|
|
// session cookies instead, which causes a redirect loop despite loggin in and session
|
|
// cookies being set.
|
|
//
|
|
// TODO: This should be removed in a future version.
|
|
if c := strings.TrimSpace(c.Request().Header.Get("Cookie")); strings.Contains(c, "session=") {
|
|
hdr = ""
|
|
}
|
|
|
|
if len(hdr) > 0 {
|
|
key, token, err := parseAuthHeader(hdr)
|
|
if err != nil {
|
|
c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, err.Error()))
|
|
return next(c)
|
|
}
|
|
|
|
// Validate the token.
|
|
user, ok := o.GetAPIToken(key, token)
|
|
if !ok {
|
|
c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, "invalid API credentials"))
|
|
return next(c)
|
|
}
|
|
|
|
// Set the user details on the handler context.
|
|
c.Set(UserHTTPCtxKey, user)
|
|
return next(c)
|
|
}
|
|
|
|
// Is it a cookie based session?
|
|
sess, user, err := o.validateSession(c)
|
|
if err != nil {
|
|
c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, "invalid session"))
|
|
return next(c)
|
|
}
|
|
|
|
// Set the user details on the handler context.
|
|
c.Set(UserHTTPCtxKey, user)
|
|
c.Set(SessionKey, sess)
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// Perm is an HTTP handler middleware that checks if the authenticated user has the required permissions.
|
|
func (o *Auth) Perm(next echo.HandlerFunc, perms ...string) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
u, ok := c.Get(UserHTTPCtxKey).(User)
|
|
if !ok {
|
|
c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, "invalid session"))
|
|
return next(c)
|
|
}
|
|
|
|
// If the current user is a Super Admin user, do no checks.
|
|
if u.UserRole.ID == SuperAdminRoleID {
|
|
return next(c)
|
|
}
|
|
|
|
// Check if the current handler's permission is in the user's permission map.
|
|
var (
|
|
has = false
|
|
perm = ""
|
|
)
|
|
for _, perm = range perms {
|
|
if _, ok := u.PermissionsMap[perm]; ok {
|
|
has = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !has {
|
|
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("permission denied: %s", perm))
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
// SaveSession creates and sets a session (post successful login/auth).
|
|
func (o *Auth) SaveSession(u User, oidcToken string, c echo.Context) error {
|
|
sess, err := o.sess.NewSession(c, c)
|
|
if err != nil {
|
|
o.log.Printf("error creating login session: %v", err)
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "error creating session")
|
|
}
|
|
|
|
if err := sess.SetMulti(map[string]any{"user_id": u.ID, "oidc_token": oidcToken}); err != nil {
|
|
o.log.Printf("error setting login session: %v", err)
|
|
return echo.NewHTTPError(http.StatusInternalServerError, "error creating session")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateSession checks if the cookie session is valid (in the DB) and returns the session and user details.
|
|
func (o *Auth) validateSession(c echo.Context) (*simplesessions.Session, User, error) {
|
|
// Cookie session.
|
|
sess, err := o.sess.Acquire(context.TODO(), c, c)
|
|
if err != nil {
|
|
return nil, User{}, echo.NewHTTPError(http.StatusForbidden, err.Error())
|
|
}
|
|
|
|
// Get the session variables.
|
|
vars, err := sess.GetMulti("user_id", "oidc_token")
|
|
if err != nil {
|
|
return nil, User{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
|
|
}
|
|
|
|
// Validate the user ID in the session.
|
|
userID, err := o.sessStore.Int(vars["user_id"], nil)
|
|
if err != nil || userID < 1 {
|
|
o.log.Printf("error fetching session user ID: %v", err)
|
|
return nil, User{}, echo.NewHTTPError(http.StatusInternalServerError, "invalid session.")
|
|
}
|
|
|
|
// Fetch user details from the database.
|
|
user, err := o.cb.GetUser(userID)
|
|
if err != nil {
|
|
o.log.Printf("error fetching session user: %v", err)
|
|
}
|
|
|
|
return sess, user, err
|
|
}
|
|
|
|
// GetUser retrieves and returns the User object from an authenticated
|
|
// HTTP handler request.
|
|
func GetUser(c echo.Context) User {
|
|
return c.Get(UserHTTPCtxKey).(User)
|
|
}
|
|
|
|
// parseAuthHeader parses the Authorization header and returns the api_key and access_token.
|
|
func parseAuthHeader(h string) (string, string, error) {
|
|
const authBasic = "Basic"
|
|
const authToken = "token"
|
|
|
|
var (
|
|
pair []string
|
|
delim = ":"
|
|
)
|
|
|
|
if strings.HasPrefix(h, authToken) {
|
|
// token api_key:access_token.
|
|
pair = strings.SplitN(strings.Trim(h[len(authToken):], " "), delim, 2)
|
|
} else if strings.HasPrefix(h, authBasic) {
|
|
// HTTP BasicAuth. This is supported for backwards compatibility.
|
|
payload, err := base64.StdEncoding.DecodeString(string(strings.Trim(h[len(authBasic):], " ")))
|
|
if err != nil {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "invalid Base64 value in Basic Authorization header")
|
|
}
|
|
pair = strings.SplitN(string(payload), delim, 2)
|
|
} else {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "unknown Authorization scheme")
|
|
}
|
|
|
|
if len(pair) < 2 {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "api_key:token missing")
|
|
}
|
|
|
|
if len(pair[0]) == 0 || len(pair[1]) == 0 {
|
|
return "", "", echo.NewHTTPError(http.StatusBadRequest, "empty `api_key` or `token`")
|
|
}
|
|
|
|
return pair[0], pair[1], nil
|
|
}
|