mirror of
https://github.com/juanfont/headscale.git
synced 2024-09-20 07:16:35 +08:00
Compare commits
3 commits
bfaa245274
...
515f584254
Author | SHA1 | Date | |
---|---|---|---|
515f584254 | |||
64fd5f484c | |||
38c148745a |
|
@ -133,19 +133,25 @@ func (h *Headscale) RegisterOIDC(
|
|||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
// generate PKCE code verifier
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
// place the node key and verifier into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(
|
||||
stateStr,
|
||||
machineKey,
|
||||
types.RegistrationInfo{
|
||||
MachineKey: machineKey,
|
||||
Verifier: verifier,
|
||||
},
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)+2)
|
||||
|
||||
for k, v := range h.cfg.OIDC.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
extras = append(extras, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
|
||||
|
||||
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
@ -179,7 +185,33 @@ func (h *Headscale) OIDCCallback(
|
|||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
|
||||
regState, stateFound := h.registrationCache.Get(state)
|
||||
if !stateFound {
|
||||
log.Trace().
|
||||
Msg("requested state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state has expired"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
return
|
||||
}
|
||||
regInfo, regInfoOK := regState.(types.RegistrationInfo)
|
||||
if !regInfoOK {
|
||||
log.Trace().
|
||||
Interface("got", regInfo).
|
||||
Msg("requested state is not a RegistrationInfo")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state, regInfo)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -216,7 +248,7 @@ func (h *Headscale) OIDCCallback(
|
|||
|
||||
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
||||
writer,
|
||||
state,
|
||||
regInfo,
|
||||
claims,
|
||||
idTokenExpiry,
|
||||
)
|
||||
|
@ -278,8 +310,9 @@ func (h *Headscale) getIDTokenForOIDCCallback(
|
|||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
code, state string,
|
||||
regInfo types.RegistrationInfo,
|
||||
) (string, error) {
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code, oauth2.VerifierOption(regInfo.Verifier))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Could not exchange code for token")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -441,46 +474,17 @@ func validateOIDCAllowedUsers(
|
|||
// on to registration.
|
||||
func (h *Headscale) validateNodeForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
state string,
|
||||
state types.RegistrationInfo,
|
||||
claims *IDTokenClaims,
|
||||
expiry time.Time,
|
||||
) (*key.MachinePublic, bool, error) {
|
||||
// retrieve nodekey from state cache
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
if !machineKeyFound {
|
||||
log.Trace().
|
||||
Msg("requested node state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state has expired"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
|
||||
if !machineKeyOK {
|
||||
log.Trace().
|
||||
Interface("got", machineKeyIf).
|
||||
Msg("requested node state key is not a nodekey")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCInvalidNodeState
|
||||
}
|
||||
|
||||
// retrieve node information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := h.db.GetNodeByMachineKey(machineKey)
|
||||
node, _ := h.db.GetNodeByMachineKey(state.MachineKey)
|
||||
|
||||
if node != nil {
|
||||
log.Trace().
|
||||
|
@ -542,7 +546,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
return nil, true, nil
|
||||
}
|
||||
|
||||
return &machineKey, false, nil
|
||||
return &state.MachineKey, false, nil
|
||||
}
|
||||
|
||||
func getUserName(
|
||||
|
|
|
@ -558,3 +558,9 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
|
|||
|
||||
return ret
|
||||
}
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
MachineKey key.MachinePublic
|
||||
Verifier string
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue