listmonk/internal/auth/auth.go

455 lines
13 KiB
Go

package auth
import (
"context"
"crypto/subtle"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/labstack/echo/v4"
"github.com/zerodha/simplesessions/stores/postgres/v3"
"github.com/zerodha/simplesessions/v3"
"golang.org/x/oauth2"
)
// OIDCclaim holds the subset of OpenID Connect claims this package reads.
// ExchangeOIDCToken decodes it from the verified ID token and, when Email is
// empty there, additionally from the provider's userinfo response.
type OIDCclaim struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
// Sub is the provider's subject identifier for the end-user.
Sub string `json:"sub"`
Picture string `json:"picture"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
}
// OIDCConfig configures OpenID Connect login.
type OIDCConfig struct {
// Enabled must be true for initOIDC to set up the provider; otherwise
// initialization fails with "OIDC is not enabled".
Enabled bool `json:"enabled"`
// ProviderURL is the issuer URL passed to oidc.NewProvider for discovery.
ProviderURL string `json:"provider_url"`
// RedirectURL is used as the OAuth2 config's RedirectURL.
RedirectURL string `json:"redirect_url"`
// ClientID is the OAuth2 client ID; it is also given to the ID token
// verifier as the expected client (audience).
ClientID string `json:"client_id"`
// ClientSecret is the OAuth2 client secret used for the code exchange.
ClientSecret string `json:"client_secret"`
// AutoCreateUsers, DefaultUserRoleID and DefaultListRoleID are not read in
// this file; presumably they control provisioning of unknown users on first
// OIDC login — TODO confirm against callers.
AutoCreateUsers bool `json:"auto_create_users"`
DefaultUserRoleID int `json:"default_user_role_id"`
DefaultListRoleID int `json:"default_list_role_id"`
}
// BasicAuthConfig holds BasicAuth credentials.
// NOTE(review): none of these fields are read in this file — confirm where
// (or whether) they are still used.
type BasicAuthConfig struct {
Enabled bool `json:"enabled"`
Username string `json:"username"`
Password string `json:"password"`
}
// Config is the auth configuration passed to New.
type Config struct {
// OIDC configures OpenID Connect login (read by initOIDC).
OIDC OIDCConfig
// BasicAuth is not read in this file. NOTE(review): confirm usage.
BasicAuth BasicAuthConfig
}
// Callbacks holds the three callback functions Auth depends on. SetCookie and
// GetCookie are installed on the simplesessions manager as its cookie hooks,
// and GetUser loads a user by ID when validating a cookie session.
type Callbacks struct {
// SetCookie writes a session cookie. w is whatever writer value the session
// manager passes through (presumably the echo.Context given to
// NewSession/Acquire in this file — confirm).
SetCookie func(cookie *http.Cookie, w any) error
// GetCookie reads the named cookie from r (presumably the echo.Context given
// to NewSession/Acquire — confirm).
GetCookie func(name string, r any) (*http.Cookie, error)
// GetUser fetches the user with the given ID; called by validateSession.
GetUser func(id int) (User, error)
}
// Auth authenticates HTTP requests using cached API users (token/BasicAuth
// headers), cookie sessions kept in a Postgres session store, and OIDC login.
//
// The embedded RWMutex guards apiUsers and the lazily initialized OIDC
// fields (oauthCfg, verifier, provider).
type Auth struct {
// apiUsers maps API usernames to users; replaced/updated by CacheAPIUsers
// and CacheAPIUser, read by GetAPIToken.
apiUsers map[string]User
sync.RWMutex
cfg Config
// oauthCfg, verifier and provider are populated together by initOIDC on
// first use.
oauthCfg oauth2.Config
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
sess *simplesessions.Manager
sessStore *postgres.Store
cb *Callbacks
log *log.Logger
}
// sessPruneInterval is how long the background pruner started by New waits
// between successive prunes of dead sessions in the DB store.
var sessPruneInterval = time.Hour * 12

// New returns an initialized Auth instance whose cookie sessions are stored in
// Postgres via db. cb supplies the cookie get/set hooks used by the session
// manager and the user lookup used when validating sessions; lo receives log
// output.
//
// New also starts a goroutine that prunes dead sessions immediately and then
// every sessPruneInterval for the lifetime of the process.
// NOTE(review): the goroutine has no stop mechanism; add a context or stop
// channel if Auth ever needs to be torn down (e.g. in tests).
func New(cfg Config, db *sql.DB, cb *Callbacks, lo *log.Logger) (*Auth, error) {
	a := &Auth{
		cfg:      cfg,
		cb:       cb,
		log:      lo,
		apiUsers: map[string]User{},
	}

	// Initialize session manager.
	a.sess = simplesessions.New(simplesessions.Options{
		EnableAutoCreate: false,
		SessionIDLength:  64,
		Cookie: simplesessions.CookieOptions{
			IsHTTPOnly: true,
			MaxAge:     time.Hour * 24 * 7,
		},
	})

	st, err := postgres.New(postgres.Opt{}, db)
	if err != nil {
		return nil, fmt.Errorf("initializing session store: %w", err)
	}
	a.sessStore = st
	a.sess.UseStore(st)
	a.sess.SetCookieHooks(cb.GetCookie, cb.SetCookie)

	// Prune dead sessions from the DB periodically. The loop is required:
	// without it the goroutine would prune once and exit.
	go func() {
		for {
			if err := st.Prune(); err != nil {
				lo.Printf("error pruning login sessions: %v", err)
			}
			time.Sleep(sessPruneInterval)
		}
	}()

	return a, nil
}
// CacheAPIUsers replaces the whole API user cache with users, keyed by
// username. Previously cached users that are absent from users are dropped,
// so this is meant for syncing all API users from the database in one shot.
// If a username appears more than once, the last occurrence wins.
func (o *Auth) CacheAPIUsers(users []User) {
	fresh := make(map[string]User, len(users))
	for i := range users {
		fresh[users[i].Username] = users[i]
	}

	// Swap the new map in atomically with respect to readers.
	o.Lock()
	o.apiUsers = fresh
	o.Unlock()
}
// CacheAPIUser adds u to the API user cache, replacing any cached user with
// the same username.
func (o *Auth) CacheAPIUser(u User) {
	o.Lock()
	defer o.Unlock()

	o.apiUsers[u.Username] = u
}
// GetAPIToken validates an API user+token pair against the cache. It returns
// the cached user and true when user is cached and token equals the user's
// stored password/token; otherwise it returns an empty User and false.
func (o *Auth) GetAPIToken(user string, token string) (User, bool) {
	o.RLock()
	cached, found := o.apiUsers[user]
	o.RUnlock()

	if !found {
		return User{}, false
	}

	// Constant-time comparison so response timing doesn't reveal how much of
	// the token matched.
	if subtle.ConstantTimeCompare([]byte(cached.Password.String), []byte(token)) != 1 {
		return User{}, false
	}

	return cached, true
}
// initOIDC initializes the OIDC provider, ID token verifier, and OAuth2 config
// (openid, profile and email scopes) from o.cfg.OIDC.
//
// It returns an error if OIDC is disabled or provider discovery fails; in that
// case none of o.provider, o.verifier or o.oauthCfg are modified. The discovery
// error is wrapped with %w so callers can inspect it with errors.Is/As.
//
// The caller must hold o's write lock, since this writes shared fields.
func (o *Auth) initOIDC() error {
	if !o.cfg.OIDC.Enabled {
		return errors.New("OIDC is not enabled")
	}

	// NewProvider uses OpenID Connect discovery against the issuer URL.
	provider, err := oidc.NewProvider(context.Background(), o.cfg.OIDC.ProviderURL)
	if err != nil {
		return fmt.Errorf("error initializing OIDC OAuth provider: %w", err)
	}

	o.verifier = provider.Verifier(&oidc.Config{
		ClientID: o.cfg.OIDC.ClientID,
	})

	o.oauthCfg = oauth2.Config{
		ClientID:     o.cfg.OIDC.ClientID,
		ClientSecret: o.cfg.OIDC.ClientSecret,
		Endpoint:     provider.Endpoint(),
		RedirectURL:  o.cfg.OIDC.RedirectURL,
		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
	}

	o.provider = provider
	return nil
}
// ensureOIDC returns the OIDC provider, verifier and a copy of the OAuth2
// config, lazily initializing them via initOIDC on first successful call.
//
// o.provider is the single "initialized" sentinel. initOIDC sets provider,
// verifier and oauthCfg together on success and touches none of them on
// failure, so all getters agree on whether OIDC is ready. Once initialized,
// callers only take the read lock.
// NOTE(review): first-time initialization does network discovery while
// holding the write lock, which blocks GetAPIToken readers until it finishes.
func (o *Auth) ensureOIDC() (*oidc.Provider, *oidc.IDTokenVerifier, oauth2.Config, error) {
	o.RLock()
	provider, verifier, cfg := o.provider, o.verifier, o.oauthCfg
	o.RUnlock()
	if provider != nil {
		return provider, verifier, cfg, nil
	}

	o.Lock()
	defer o.Unlock()

	// Re-check: another goroutine may have initialized while we waited.
	if o.provider == nil {
		if err := o.initOIDC(); err != nil {
			return nil, nil, oauth2.Config{}, err
		}
	}
	return o.provider, o.verifier, o.oauthCfg, nil
}

// getProvider returns the OIDC provider, initializing it if necessary.
func (o *Auth) getProvider() (*oidc.Provider, error) {
	provider, _, _, err := o.ensureOIDC()
	if err != nil {
		return nil, err
	}
	return provider, nil
}

// getVerifier returns the OIDC verifier, initializing it if necessary.
func (o *Auth) getVerifier() (*oidc.IDTokenVerifier, error) {
	_, verifier, _, err := o.ensureOIDC()
	if err != nil {
		return nil, err
	}
	return verifier, nil
}

// getOAuthConfig returns the OAuth config, initializing it if necessary.
// It returns a pointer to a copy so callers never alias the shared field
// outside the lock.
func (o *Auth) getOAuthConfig() (*oauth2.Config, error) {
	_, _, cfg, err := o.ensureOIDC()
	if err != nil {
		return nil, err
	}
	return &cfg, nil
}
// GetOIDCAuthURL builds the OIDC provider's authorization URL to redirect the
// browser to, embedding state and nonce. If the OAuth config cannot be
// obtained, the error is logged and an empty string is returned.
func (o *Auth) GetOIDCAuthURL(state, nonce string) string {
	oauthCfg, err := o.getOAuthConfig()
	if err == nil {
		return oauthCfg.AuthCodeURL(state, oidc.Nonce(nonce))
	}

	o.log.Printf("error getting OAuth config: %v", err)
	return ""
}
// ExchangeOIDCToken takes an OIDC authorization code (received via redirect
// from the OIDC provider) and exchanges it for tokens. It verifies the
// returned ID token and its nonce, then returns the raw ID token and the
// user's claims.
//
// If the ID token carries no e-mail, the claims are supplemented from the
// provider's userinfo endpoint. Per OpenID Connect Core §5.3.2, the userinfo
// response is only used when its "sub" exactly matches the ID token's "sub".
//
// Config, exchange, verification, nonce and subject failures are returned as
// *echo.HTTPError with status 401. Claim-decoding and userinfo failures are
// plain errors, and their underlying causes are logged.
func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error) {
	ctx := context.TODO()

	cfg, err := o.getOAuthConfig()
	if err != nil {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting OAuth config: %v", err))
	}

	tk, err := cfg.Exchange(ctx, code)
	if err != nil {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
	}

	rawIDTk, ok := tk.Extra("id_token").(string)
	if !ok {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
	}

	verifier, err := o.getVerifier()
	if err != nil {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting verifier: %v", err))
	}

	idTk, err := verifier.Verify(ctx, rawIDTk)
	if err != nil {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
	}

	if idTk.Nonce != nonce {
		return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match")
	}

	var claims OIDCclaim
	if err := idTk.Claims(&claims); err != nil {
		o.log.Printf("error decoding OIDC ID token claims: %v", err)
		return "", OIDCclaim{}, errors.New("error getting user from OIDC")
	}

	// If claims doesn't have the e-mail, attempt to fetch it from the userinfo endpoint.
	if claims.Email == "" {
		provider, err := o.getProvider()
		if err != nil {
			return "", OIDCclaim{}, fmt.Errorf("error getting provider: %w", err)
		}

		userInfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(tk))
		if err != nil {
			o.log.Printf("error fetching OIDC user info: %v", err)
			return "", OIDCclaim{}, errors.New("error fetching user info from OIDC")
		}

		// The userinfo response must describe the same end-user as the verified
		// ID token; otherwise its values must not be used. Without this check
		// userinfo's "sub" would silently overwrite claims.Sub below.
		if userInfo.Subject != idTk.Subject {
			return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "userinfo subject does not match ID token subject")
		}

		// Parse the UserInfo claims into the claims struct.
		if err := userInfo.Claims(&claims); err != nil {
			o.log.Printf("error decoding OIDC user info claims: %v", err)
			return "", OIDCclaim{}, errors.New("error parsing user info claims")
		}
	}

	return rawIDTk, claims, nil
}
// Middleware is the HTTP middleware used for wrapping HTTP handlers registered on the echo router.
// It authorizes token (BasicAuth/token) based and cookie based sessions. On successful auth it
// sets the authenticated User{} on the echo context under UserHTTPCtxKey (and, for cookie
// sessions, the session under SessionKey). On failure it sets an *echo.HTTPError under
// UserHTTPCtxKey instead. It always calls next; a later handler or middleware must check which
// of the two the key holds.
func (o *Auth) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		// It's an `Authorization` header request.
		hdr := strings.TrimSpace(c.Request().Header.Get("Authorization"))

		// If a session cookie is set, ignore BasicAuth. This is to preserve backwards
		// compatibility in the v3 -> v4 upgrade. User browser sessions would still carry old
		// BasicAuth credentials, which no longer work in the new system that expects session
		// cookies instead. That causes a redirect loop despite logging in and session cookies
		// being set.
		//
		// Only a cookie named exactly "session" counts. That is the name the previous substring
		// check targeted (presumably simplesessions' default, as New configures no cookie name
		// — confirm). A raw substring match on the Cookie header also fired on unrelated
		// cookies such as "_myapp_session=", silently disabling API auth for those clients.
		//
		// TODO: This should be removed in a future version.
		if _, err := c.Cookie("session"); err == nil {
			hdr = ""
		}

		if hdr != "" {
			key, token, err := parseAuthHeader(hdr)
			if err != nil {
				c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, err.Error()))
				return next(c)
			}

			// Validate the token.
			user, ok := o.GetAPIToken(key, token)
			if !ok {
				c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, "invalid API credentials"))
				return next(c)
			}

			// Set the user details on the handler context.
			c.Set(UserHTTPCtxKey, user)
			return next(c)
		}

		// Is it a cookie based session?
		sess, user, err := o.validateSession(c)
		if err != nil {
			c.Set(UserHTTPCtxKey, echo.NewHTTPError(http.StatusForbidden, "invalid session"))
			return next(c)
		}

		// Set the user details on the handler context.
		c.Set(UserHTTPCtxKey, user)
		c.Set(SessionKey, sess)
		return next(c)
	}
}
// Perm is an HTTP handler middleware that checks if the authenticated user has the required
// permissions. next is called only if the User on the context (set by Middleware) is a Super
// Admin or holds at least one of perms. Otherwise a 403 *echo.HTTPError is returned.
//
// If the context holds no authenticated User (e.g. Middleware stored an auth error), the request
// is rejected rather than passed to next. This fails closed, so a handler wrapped by Perm never
// runs unchecked, and GetUser in that handler cannot panic on a non-User value.
func (o *Auth) Perm(next echo.HandlerFunc, perms ...string) echo.HandlerFunc {
	return func(c echo.Context) error {
		val := c.Get(UserHTTPCtxKey)

		u, ok := val.(User)
		if !ok {
			// Surface the auth error Middleware recorded, if there is one.
			if herr, isErr := val.(*echo.HTTPError); isErr {
				return herr
			}
			return echo.NewHTTPError(http.StatusForbidden, "invalid session")
		}

		// If the current user is a Super Admin user, do no checks.
		if u.UserRole.ID == SuperAdminRoleID {
			return next(c)
		}

		// Any one of the listed permissions is sufficient.
		for _, p := range perms {
			if _, has := u.PermissionsMap[p]; has {
				return next(c)
			}
		}

		return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("permission denied: %s", strings.Join(perms, ", ")))
	}
}
// SaveSession creates a new session after a successful login/auth and stores
// u's ID and oidcToken in it. Failures are logged, and the client receives a
// generic 500 "error creating session" error.
func (o *Auth) SaveSession(u User, oidcToken string, c echo.Context) error {
	const clientMsg = "error creating session"

	sess, err := o.sess.NewSession(c, c)
	if err != nil {
		o.log.Printf("error creating login session: %v", err)
		return echo.NewHTTPError(http.StatusInternalServerError, clientMsg)
	}

	vals := map[string]any{
		"user_id":    u.ID,
		"oidc_token": oidcToken,
	}
	if err = sess.SetMulti(vals); err != nil {
		o.log.Printf("error setting login session: %v", err)
		return echo.NewHTTPError(http.StatusInternalServerError, clientMsg)
	}

	return nil
}
// validateSession checks that the request's cookie session exists in the
// store and carries a valid user ID. It then loads that user via
// cb.GetUser and returns the session and user.
//
// On any failure it returns a nil session and an empty User along with the
// error. Session acquisition failures are 403s, session-data failures are 500s,
// and a GetUser failure is returned unchanged.
func (o *Auth) validateSession(c echo.Context) (*simplesessions.Session, User, error) {
	// Cookie session.
	sess, err := o.sess.Acquire(context.TODO(), c, c)
	if err != nil {
		return nil, User{}, echo.NewHTTPError(http.StatusForbidden, err.Error())
	}

	// Get the session variables.
	vars, err := sess.GetMulti("user_id", "oidc_token")
	if err != nil {
		return nil, User{}, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
	}

	// Validate the user ID in the session. Conversion failures and
	// non-positive IDs are logged separately so the log never reports a
	// "<nil>" error.
	userID, err := o.sessStore.Int(vars["user_id"], nil)
	if err != nil {
		o.log.Printf("error fetching session user ID: %v", err)
		return nil, User{}, echo.NewHTTPError(http.StatusInternalServerError, "invalid session.")
	}
	if userID < 1 {
		o.log.Printf("invalid session user ID: %d", userID)
		return nil, User{}, echo.NewHTTPError(http.StatusInternalServerError, "invalid session.")
	}

	// Fetch user details from the database.
	user, err := o.cb.GetUser(userID)
	if err != nil {
		o.log.Printf("error fetching session user: %v", err)
		return nil, User{}, err
	}

	return sess, user, nil
}
// GetUser returns the authenticated User that Middleware stored on the echo
// context under UserHTTPCtxKey.
//
// It panics if that value is not a User (for example, when authentication
// failed and Middleware stored an error instead). Only call it from handlers
// that run after a successful auth check.
func GetUser(c echo.Context) User {
	v := c.Get(UserHTTPCtxKey)
	return v.(User)
}
// parseAuthHeader parses the Authorization header and returns the api_key and access_token.
//
// Two schemes are accepted:
//
//	token <api_key>:<access_token>
//	Basic <base64(api_key:access_token)>   (kept for backwards compatibility)
//
// The scheme must be a separate word followed by a space, and it is matched
// case-insensitively as required by RFC 7235 §2.1. Only the first ":" splits
// key from token, so the token may itself contain ":". All errors are 400
// *echo.HTTPError values.
func parseAuthHeader(h string) (string, string, error) {
	const (
		authBasic = "Basic"
		authToken = "token"
		delim     = ":"
	)

	// Split "<scheme> <credentials>".
	scheme, creds, _ := strings.Cut(strings.TrimSpace(h), " ")
	creds = strings.TrimSpace(creds)

	var pair []string
	switch {
	case strings.EqualFold(scheme, authToken):
		// token api_key:access_token.
		pair = strings.SplitN(creds, delim, 2)

	case strings.EqualFold(scheme, authBasic):
		// HTTP BasicAuth. This is supported for backwards compatibility.
		payload, err := base64.StdEncoding.DecodeString(creds)
		if err != nil {
			return "", "", echo.NewHTTPError(http.StatusBadRequest, "invalid Base64 value in Basic Authorization header")
		}
		pair = strings.SplitN(string(payload), delim, 2)

	default:
		return "", "", echo.NewHTTPError(http.StatusBadRequest, "unknown Authorization scheme")
	}

	if len(pair) < 2 {
		return "", "", echo.NewHTTPError(http.StatusBadRequest, "api_key:token missing")
	}
	if pair[0] == "" || pair[1] == "" {
		return "", "", echo.NewHTTPError(http.StatusBadRequest, "empty `api_key` or `token`")
	}

	return pair[0], pair[1], nil
}