diff --git a/oidc.go b/oidc.go index fb27354b..e569fe10 100644 --- a/oidc.go +++ b/oidc.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net/http" "regexp" @@ -15,6 +16,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" + "gorm.io/gorm" ) const ( @@ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error { h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) if err != nil { - log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) + log.Error(). + Err(err). + Caller(). + Msgf("Could not retrieve OIDC Config: %s", err.Error()) return err } @@ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error { // Puts machine key in cache so the callback can retrieve it using the oidc state param // Listens in /oidc/register/:mKey. func (h *Headscale) RegisterOIDC(ctx *gin.Context) { - mKeyStr := ctx.Param("mkey") - if mKeyStr == "" { + machineKeyStr := ctx.Param("mkey") + if machineKeyStr == "" { ctx.String(http.StatusBadRequest, "Wrong params") return @@ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { randomBlob := make([]byte, randomByteSize) if _, err := rand.Read(randomBlob); err != nil { - log.Error().Msg("could not read 16 bytes from rand") + log.Error(). + Caller(). + Msg("could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") return @@ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration) + h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) authURL := h.oauth2Config.AuthCodeURL(stateStr) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { - ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + log.Error(). + Err(err). + Caller(). + Msg("failed to verify id token") + ctx.String(http.StatusBadRequest, "Failed to verify id token") return } @@ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { // Extract custom claims var claims IDTokenClaims if err = idToken.Claims(&claims); err != nil { + log.Error(). + Err(err). + Caller(). + Msg("Failed to decode id token claims") ctx.String( http.StatusBadRequest, - fmt.Sprintf("Failed to decode id token claims: %s", err), + fmt.Sprintf("Failed to decode id token claims"), ) return } // retrieve machinekey from state cache - mKeyIf, mKeyFound := h.oidcStateCache.Get(state) + machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) - if !mKeyFound { + if !machineKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") ctx.String(http.StatusBadRequest, "state has expired") return } - mKeyStr, mKeyOK := mKeyIf.(string) + machineKey, machineKeyOK := machineKeyIf.(string) - if !mKeyOK { + if !machineKeyOK { log.Error().Msg("could not get machine key from cache") ctx.String( http.StatusInternalServerError, @@ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { } // retrieve machine information - machine, err := h.GetMachineByMachineKey(mKeyStr) + machine, err := h.GetMachineByMachineKey(machineKey) if err != nil { log.Error().Msg("machine key not found in database") ctx.String( @@ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { log.Debug().Msg("Registering new machine after successful callback") namespace, err := h.GetNamespace(namespaceName) - if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { namespace, err = h.CreateNamespace(namespaceName) if err != nil { log.Error(). - Msgf("could not create new namespace '%s'", claims.Email) + Err(err). + Caller(). + Msgf("could not create new namespace '%s'", namespaceName) ctx.String( http.StatusInternalServerError, "could not create new namespace", @@ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + } else if err != nil { + log.Error(). + Caller(). + Err(err). + Str("namespace", namespaceName). + Msg("could not find or create namespace") + ctx.String( + http.StatusInternalServerError, + "could not find or create namespace", + ) + + return } ip, err := h.getAvailableIP() if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not get an IP from the pool") ctx.String( http.StatusInternalServerError, "could not get an IP from the pool", @@ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { } log.Error(). + Caller(). Str("email", claims.Email). Str("username", claims.Username). Str("machine", machine.Name).