mirror of
https://github.com/knadh/listmonk.git
synced 2025-09-13 09:54:37 +08:00
Change OIDC init to lazy-load instead of loading once on boot. Fixes #2626.
This commit is contained in:
parent
961116468b
commit
7d38890868
2 changed files with 95 additions and 30 deletions
10
cmd/auth.go
10
cmd/auth.go
|
@ -34,11 +34,11 @@ type oidcState struct {
|
||||||
Next string `json:"next"`
|
Next string `json:"next"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var oidcProviders = map[string]bool{
|
var oidcProviders = map[string]struct{}{
|
||||||
"google.com": true,
|
"google.com": {},
|
||||||
"microsoftonline.com": true,
|
"microsoftonline.com": {},
|
||||||
"auth0.com": true,
|
"auth0.com": {},
|
||||||
"github.com": true,
|
"github.com": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginPage renders the login page and handles the login form.
|
// LoginPage renders the login page and handles the login form.
|
||||||
|
|
|
@ -84,27 +84,6 @@ func New(cfg Config, db *sql.DB, cb *Callbacks, lo *log.Logger) (*Auth, error) {
|
||||||
apiUsers: map[string]User{},
|
apiUsers: map[string]User{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize OIDC.
|
|
||||||
if cfg.OIDC.Enabled {
|
|
||||||
provider, err := oidc.NewProvider(context.Background(), cfg.OIDC.ProviderURL)
|
|
||||||
if err != nil {
|
|
||||||
cfg.OIDC.Enabled = false
|
|
||||||
lo.Printf("error initializing OIDC OAuth provider: %v", err)
|
|
||||||
} else {
|
|
||||||
a.verifier = provider.Verifier(&oidc.Config{
|
|
||||||
ClientID: cfg.OIDC.ClientID,
|
|
||||||
})
|
|
||||||
|
|
||||||
a.oauthCfg = oauth2.Config{
|
|
||||||
ClientID: cfg.OIDC.ClientID,
|
|
||||||
ClientSecret: cfg.OIDC.ClientSecret,
|
|
||||||
Endpoint: provider.Endpoint(),
|
|
||||||
RedirectURL: cfg.OIDC.RedirectURL,
|
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
||||||
}
|
|
||||||
a.provider = provider
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize session manager.
|
// Initialize session manager.
|
||||||
a.sess = simplesessions.New(simplesessions.Options{
|
a.sess = simplesessions.New(simplesessions.Options{
|
||||||
|
@ -167,15 +146,91 @@ func (o *Auth) GetAPIToken(user string, token string) (User, bool) {
|
||||||
return t, true
|
return t, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initOIDC initializes the OIDC provider, verifier, and OAuth config.
|
||||||
|
func (o *Auth) initOIDC() error {
|
||||||
|
if !o.cfg.OIDC.Enabled {
|
||||||
|
return fmt.Errorf("OIDC is not enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := oidc.NewProvider(context.Background(), o.cfg.OIDC.ProviderURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error initializing OIDC OAuth provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.verifier = provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: o.cfg.OIDC.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
o.oauthCfg = oauth2.Config{
|
||||||
|
ClientID: o.cfg.OIDC.ClientID,
|
||||||
|
ClientSecret: o.cfg.OIDC.ClientSecret,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
RedirectURL: o.cfg.OIDC.RedirectURL,
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
|
}
|
||||||
|
o.provider = provider
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProvider returns the OIDC provider, initializing it if necessary.
|
||||||
|
func (o *Auth) getProvider() (*oidc.Provider, error) {
|
||||||
|
o.Lock()
|
||||||
|
defer o.Unlock()
|
||||||
|
|
||||||
|
if o.provider == nil {
|
||||||
|
if err := o.initOIDC(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o.provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getVerifier returns the OIDC verifier, initializing it if necessary.
|
||||||
|
func (o *Auth) getVerifier() (*oidc.IDTokenVerifier, error) {
|
||||||
|
o.Lock()
|
||||||
|
defer o.Unlock()
|
||||||
|
|
||||||
|
if o.verifier == nil {
|
||||||
|
if err := o.initOIDC(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o.verifier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOAuthConfig returns the OAuth config, initializing it if necessary.
|
||||||
|
func (o *Auth) getOAuthConfig() (*oauth2.Config, error) {
|
||||||
|
o.Lock()
|
||||||
|
defer o.Unlock()
|
||||||
|
|
||||||
|
if o.oauthCfg.ClientID == "" {
|
||||||
|
if err := o.initOIDC(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &o.oauthCfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetOIDCAuthURL returns the OIDC provider's auth URL to redirect to.
|
// GetOIDCAuthURL returns the OIDC provider's auth URL to redirect to.
|
||||||
func (o *Auth) GetOIDCAuthURL(state, nonce string) string {
|
func (o *Auth) GetOIDCAuthURL(state, nonce string) string {
|
||||||
return o.oauthCfg.AuthCodeURL(state, oidc.Nonce(nonce))
|
cfg, err := o.getOAuthConfig()
|
||||||
|
if err != nil {
|
||||||
|
o.log.Printf("error getting OAuth config: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cfg.AuthCodeURL(state, oidc.Nonce(nonce))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeOIDCToken takes an OIDC authorization code (recieved via redirect from the OIDC provider),
|
// ExchangeOIDCToken takes an OIDC authorization code (recieved via redirect from the OIDC provider),
|
||||||
// validates it, and returns an OIDC token for subsequent auth.
|
// validates it, and returns an OIDC token for subsequent auth.
|
||||||
func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error) {
|
func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error) {
|
||||||
tk, err := o.oauthCfg.Exchange(context.TODO(), code)
|
cfg, err := o.getOAuthConfig()
|
||||||
|
if err != nil {
|
||||||
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting OAuth config: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
tk, err := cfg.Exchange(context.TODO(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
|
||||||
}
|
}
|
||||||
|
@ -185,7 +240,12 @@ func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error)
|
||||||
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
|
||||||
}
|
}
|
||||||
|
|
||||||
idTk, err := o.verifier.Verify(context.TODO(), rawIDTk)
|
verifier, err := o.getVerifier()
|
||||||
|
if err != nil {
|
||||||
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error getting verifier: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
idTk, err := verifier.Verify(context.TODO(), rawIDTk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
|
return "", OIDCclaim{}, echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
|
||||||
}
|
}
|
||||||
|
@ -201,7 +261,12 @@ func (o *Auth) ExchangeOIDCToken(code, nonce string) (string, OIDCclaim, error)
|
||||||
|
|
||||||
// If claims doesn't have the e-mail, attempt to fetch it from the userinfo endpoint.
|
// If claims doesn't have the e-mail, attempt to fetch it from the userinfo endpoint.
|
||||||
if claims.Email == "" {
|
if claims.Email == "" {
|
||||||
userInfo, err := o.provider.UserInfo(context.TODO(), oauth2.StaticTokenSource(tk))
|
provider, err := o.getProvider()
|
||||||
|
if err != nil {
|
||||||
|
return "", OIDCclaim{}, fmt.Errorf("error getting provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userInfo, err := provider.UserInfo(context.TODO(), oauth2.StaticTokenSource(tk))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", OIDCclaim{}, errors.New("error fetching user info from OIDC")
|
return "", OIDCclaim{}, errors.New("error fetching user info from OIDC")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue