diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 847d01c6..065203c4 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "net/url" "github.com/coreos/go-oidc/v3/oidc" "github.com/labstack/echo/v4" @@ -25,100 +24,103 @@ type Config struct { Skipper middleware.Skipper } -func OIDCAuth(config Config) echo.MiddlewareFunc { - provider, err := oidc.NewProvider(context.Background(), config.ProviderURL) +type OIDC struct { + cfg oauth2.Config + verifier *oidc.IDTokenVerifier + skipper middleware.Skipper +} + +func New(cfg Config) *OIDC { + provider, err := oidc.NewProvider(context.Background(), cfg.ProviderURL) if err != nil { panic(err) } verifier := provider.Verifier(&oidc.Config{ - ClientID: config.ClientID, + ClientID: cfg.ClientID, }) oidcConfig := oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, Endpoint: provider.Endpoint(), - RedirectURL: config.RedirectURL, + RedirectURL: cfg.RedirectURL, Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } - pathURL, err := url.Parse(config.RedirectURL) + return &OIDC{ + verifier: verifier, + cfg: oidcConfig, + skipper: cfg.Skipper, + } +} + +// HandleCallback is the HTTP handler that handles the post-OIDC provider redirect callback. +func (o *OIDC) HandleCallback(c echo.Context) error { + tk, err := o.cfg.Exchange(c.Request().Context(), c.Request().URL.Query().Get("code")) if err != nil { - panic(err) + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error exchanging token: %v", err)) } - if config.Skipper == nil { - config.Skipper = middleware.DefaultSkipper + rawIDTk, ok := tk.Extra("id_token").(string) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "`id_token` missing.") } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } + idTk, err := o.verifier.Verify(c.Request().Context(), rawIDTk) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("error verifying ID token: %v", err)) + } - if c.Request().URL.Path == pathURL.Path { - oauth2Token, err := oidcConfig.Exchange(c.Request().Context(), c.Request().URL.Query().Get("code")) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to exchange token: %v", err)) - } + nonce, err := c.Cookie("nonce") + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("nonce cookie not found: %v", err)) + } - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "No id_token field in oauth2 token") - } + if idTk.Nonce != nonce.Value { + return echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match") + } - idToken, err := verifier.Verify(c.Request().Context(), rawIDToken) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to verify ID Token: %v", err)) - } + c.SetCookie(&http.Cookie{ + Name: "id_token", + Value: rawIDTk, + Secure: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + }) - nonce, err := c.Cookie("nonce") - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("nonce cookie not found: %v", err)) - } + return c.Redirect(302, c.Request().URL.Query().Get("state")) +} - if idToken.Nonce != nonce.Value { - return echo.NewHTTPError(http.StatusUnauthorized, "nonce did not match") - } - - c.SetCookie(&http.Cookie{ - Name: "id_token", - Value: rawIDToken, - Secure: true, - SameSite: http.SameSiteLaxMode, - Path: "/", - }) - - // Login success - redirect back to the intended page - return c.Redirect(302, c.Request().URL.Query().Get("state")) - } - - // check if request is authenticated - rawIDToken, err := c.Cookie("id_token") - if err == nil { // cookie found - _, err = verifier.Verify(c.Request().Context(), rawIDToken.Value) - if err == nil { - return next(c) - } - } else if err != http.ErrNoCookie { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - - // Redirect to login - nonce, err := randString(16) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - c.SetCookie(&http.Cookie{ - Name: "nonce", - Value: nonce, - Secure: true, - SameSite: http.SameSiteLaxMode, - Path: "/", - }) - return c.Redirect(302, oidcConfig.AuthCodeURL(c.Request().URL.RequestURI(), oidc.Nonce(nonce))) +func (o *OIDC) Middleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if o.skipper != nil && o.skipper(c) { + return next(c) } + + rawIDTk, err := c.Cookie("id_token") + if err != http.ErrNoCookie { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + + // Verify the token. + _, err = o.verifier.Verify(c.Request().Context(), rawIDTk.Value) + if err == nil { + return next(c) + } + + // If the verification failed, redirect to the provider for auth. + nonce, err := randString(16) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + c.SetCookie(&http.Cookie{ + Name: "nonce", + Value: nonce, + Secure: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + }) + return c.Redirect(302, o.cfg.AuthCodeURL(c.Request().URL.RequestURI(), oidc.Nonce(nonce))) } }