diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 72fefac3..58aa0e77 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -133,19 +133,25 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] - // place the node key into the state cache, so it can be retrieved later + // generate PKCE code verifier + verifier := oauth2.GenerateVerifier() + // place the node key and verifier into the state cache, so it can be retrieved later h.registrationCache.Set( stateStr, - machineKey, + types.RegistrationInfo{ + MachineKey: machineKey, + Verifier: verifier, + }, registerCacheExpiration, ) // Add any extra parameter provided in the configuration to the Authorize Endpoint request - extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) + extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)+2) for k, v := range h.cfg.OIDC.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier)) authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -179,7 +185,33 @@ func (h *Headscale) OIDCCallback( return } - rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state) + regState, stateFound := h.registrationCache.Get(state) + if !stateFound { + log.Trace(). + Msg("requested state key expired before authorisation completed") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("state has expired")) + if err != nil { + util.LogErr(err, "Failed to write response") + } + return + } + regInfo, regInfoOK := regState.(types.RegistrationInfo) + if !regInfoOK { + log.Trace(). + Interface("got", regInfo). + Msg("requested state is not a RegistrationInfo") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("state is invalid")) + if err != nil { + util.LogErr(err, "Failed to write response") + } + return + } + + rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state, regInfo) if err != nil { return } @@ -216,7 +248,7 @@ func (h *Headscale) OIDCCallback( machineKey, nodeExists, err := h.validateNodeForOIDCCallback( writer, - state, + regInfo, claims, idTokenExpiry, ) @@ -278,8 +310,9 @@ func (h *Headscale) getIDTokenForOIDCCallback( ctx context.Context, writer http.ResponseWriter, code, state string, + regInfo types.RegistrationInfo, ) (string, error) { - oauth2Token, err := h.oauth2Config.Exchange(ctx, code) + oauth2Token, err := h.oauth2Config.Exchange(ctx, code, oauth2.VerifierOption(regInfo.Verifier)) if err != nil { util.LogErr(err, "Could not exchange code for token") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -441,46 +474,17 @@ func validateOIDCAllowedUsers( // on to registration. func (h *Headscale) validateNodeForOIDCCallback( writer http.ResponseWriter, - state string, + state types.RegistrationInfo, claims *IDTokenClaims, expiry time.Time, ) (*key.MachinePublic, bool, error) { // retrieve nodekey from state cache - machineKeyIf, machineKeyFound := h.registrationCache.Get(state) - if !machineKeyFound { - log.Trace(). - Msg("requested node state key expired before authorisation completed") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("state has expired")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, false, errOIDCNodeKeyMissing - } - - var machineKey key.MachinePublic - machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic) - if !machineKeyOK { - log.Trace(). - Interface("got", machineKeyIf). - Msg("requested node state key is not a nodekey") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("state is invalid")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, false, errOIDCInvalidNodeState - } // retrieve node information if it exist // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. - node, _ := h.db.GetNodeByMachineKey(machineKey) + node, _ := h.db.GetNodeByMachineKey(state.MachineKey) if node != nil { log.Trace(). @@ -542,7 +546,7 @@ func (h *Headscale) validateNodeForOIDCCallback( return nil, true, nil } - return &machineKey, false, nil + return &state.MachineKey, false, nil } func getUserName( diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 04ca9f8d..b19020e0 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -558,3 +558,9 @@ func (nodes Nodes) IDMap() map[NodeID]*Node { return ret } + +// RegistrationInfo contains both machine key and verifier information for OIDC validation. +type RegistrationInfo struct { + MachineKey key.MachinePublic + Verifier string +}