Refactor the oidc package and separate out handlers.

This commit is contained in:
Kailash Nadh 2024-04-02 14:30:12 +05:30
parent 8ca95f6827
commit f8b3ddb5ee

View file

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -25,52 +24,51 @@ type Config struct {
Skipper middleware.Skipper Skipper middleware.Skipper
} }
func OIDCAuth(config Config) echo.MiddlewareFunc { type OIDC struct {
provider, err := oidc.NewProvider(context.Background(), config.ProviderURL) cfg oauth2.Config
verifier *oidc.IDTokenVerifier
skipper middleware.Skipper
}
func New(cfg Config) *OIDC {
provider, err := oidc.NewProvider(context.Background(), cfg.ProviderURL)
if err != nil { if err != nil {
panic(err) panic(err)
} }
verifier := provider.Verifier(&oidc.Config{ verifier := provider.Verifier(&oidc.Config{
ClientID: config.ClientID, ClientID: cfg.ClientID,
}) })
oidcConfig := oauth2.Config{ oidcConfig := oauth2.Config{
ClientID: config.ClientID, ClientID: cfg.ClientID,
ClientSecret: config.ClientSecret, ClientSecret: cfg.ClientSecret,
Endpoint: provider.Endpoint(), Endpoint: provider.Endpoint(),
RedirectURL: config.RedirectURL, RedirectURL: cfg.RedirectURL,
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
} }
pathURL, err := url.Parse(config.RedirectURL) return &OIDC{
verifier: verifier,
cfg: oidcConfig,
skipper: cfg.Skipper,
}
}
// HandleCallback is the HTTP handler that handles the post-OIDC provider redirect callback.
func (o *OIDC) HandleCallback(c echo.Context) error {
tk, err := o.cfg.Exchange(c.Request().Context(), c.Request().URL.Query().Get("code"))
if err != nil { if err != nil {
panic(err) return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err))
} }
if config.Skipper == nil { rawIDTk, ok := tk.Extra("id_token").(string)
config.Skipper = middleware.DefaultSkipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
if c.Request().URL.Path == pathURL.Path {
oauth2Token, err := oidcConfig.Exchange(c.Request().Context(), c.Request().URL.Query().Get("code"))
if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to exchange token: %v", err))
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "No id_token field in oauth2 token") return echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.")
} }
idToken, err := verifier.Verify(c.Request().Context(), rawIDToken) idTk, err := o.verifier.Verify(c.Request().Context(), rawIDTk)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to verify ID Token: %v", err)) return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err))
} }
nonce, err := c.Cookie("nonce") nonce, err := c.Cookie("nonce")
@ -78,34 +76,39 @@ func OIDCAuth(config Config) echo.MiddlewareFunc {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("nonce cookie not found: %v", err)) return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("nonce cookie not found: %v", err))
} }
if idToken.Nonce != nonce.Value { if idTk.Nonce != nonce.Value {
return echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match") return echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match")
} }
c.SetCookie(&http.Cookie{ c.SetCookie(&http.Cookie{
Name: "id_token", Name: "id_token",
Value: rawIDToken, Value: rawIDTk,
Secure: true, Secure: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteStrictMode,
Path: "/", Path: "/",
}) })
// Login success - redirect back to the intended page
return c.Redirect(302, c.Request().URL.Query().Get("state")) return c.Redirect(302, c.Request().URL.Query().Get("state"))
} }
// check if request is authenticated func (o *OIDC) Middleware(next echo.HandlerFunc) echo.HandlerFunc {
rawIDToken, err := c.Cookie("id_token") return func(c echo.Context) error {
if err == nil { // cookie found if o.skipper != nil && o.skipper(c) {
_, err = verifier.Verify(c.Request().Context(), rawIDToken.Value)
if err == nil {
return next(c) return next(c)
} }
} else if err != http.ErrNoCookie {
rawIDTk, err := c.Cookie("id_token")
if err != http.ErrNoCookie {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
} }
// Redirect to login // Verify the token.
_, err = o.verifier.Verify(c.Request().Context(), rawIDTk.Value)
if err == nil {
return next(c)
}
// If the verification failed, redirect to the provider for auth.
nonce, err := randString(16) nonce, err := randString(16)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
@ -114,11 +117,10 @@ func OIDCAuth(config Config) echo.MiddlewareFunc {
Name: "nonce", Name: "nonce",
Value: nonce, Value: nonce,
Secure: true, Secure: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteStrictMode,
Path: "/", Path: "/",
}) })
return c.Redirect(302, oidcConfig.AuthCodeURL(c.Request().URL.RequestURI(), oidc.Nonce(nonce))) return c.Redirect(302, o.cfg.AuthCodeURL(c.Request().URL.RequestURI(), oidc.Nonce(nonce)))
}
} }
} }