This commit is contained in:
Rorical 2024-09-15 09:24:25 +00:00 committed by GitHub
commit 515f584254
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 48 additions and 38 deletions

View file

@ -133,19 +133,25 @@ func (h *Headscale) RegisterOIDC(
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
// generate PKCE code verifier
verifier := oauth2.GenerateVerifier()
// place the node key and verifier into the state cache, so it can be retrieved later
h.registrationCache.Set(
stateStr,
machineKey,
types.RegistrationInfo{
MachineKey: machineKey,
Verifier: verifier,
},
registerCacheExpiration,
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)+2)
for k, v := range h.cfg.OIDC.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
extras = append(extras, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -179,7 +185,33 @@ func (h *Headscale) OIDCCallback(
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
regState, stateFound := h.registrationCache.Get(state)
if !stateFound {
log.Trace().
Msg("requested state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
regInfo, regInfoOK := regState.(types.RegistrationInfo)
if !regInfoOK {
log.Trace().
Interface("got", regInfo).
Msg("requested state is not a RegistrationInfo")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state, regInfo)
if err != nil {
return
}
@ -216,7 +248,7 @@ func (h *Headscale) OIDCCallback(
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
writer,
state,
regInfo,
claims,
idTokenExpiry,
)
@ -278,8 +310,9 @@ func (h *Headscale) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
regInfo types.RegistrationInfo,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
oauth2Token, err := h.oauth2Config.Exchange(ctx, code, oauth2.VerifierOption(regInfo.Verifier))
if err != nil {
util.LogErr(err, "Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -441,46 +474,17 @@ func validateOIDCAllowedUsers(
// on to registration.
func (h *Headscale) validateNodeForOIDCCallback(
writer http.ResponseWriter,
state string,
state types.RegistrationInfo,
claims *IDTokenClaims,
expiry time.Time,
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCInvalidNodeState
}
// retrieve node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByMachineKey(machineKey)
node, _ := h.db.GetNodeByMachineKey(state.MachineKey)
if node != nil {
log.Trace().
@ -542,7 +546,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, true, nil
}
return &machineKey, false, nil
return &state.MachineKey, false, nil
}
func getUserName(

View file

@ -558,3 +558,9 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
return ret
}
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier string
}