Ensure we always have the key prefix when needed

This commit is contained in:
Kristoffer Dalby 2021-11-27 20:25:12 +00:00
parent c38f00fab8
commit 59aeaa8476
6 changed files with 38 additions and 8 deletions

2
api.go
View file

@ -75,7 +75,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
machineKeyStr := ctx.Param("id") machineKeyStr := ctx.Param("id")
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr)) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View file

@ -486,7 +486,9 @@ func nodesToPtables(
} }
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) err := nodeKey.UnmarshalText(
[]byte(headscale.NodePublicKeyEnsurePrefix(machine.NodeKey)),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -439,7 +439,7 @@ func (machine Machine) toNode(
includeRoutes bool, includeRoutes bool,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
var nodeKey key.NodePublic var nodeKey key.NodePublic
err := nodeKey.UnmarshalText([]byte(machine.NodeKey)) err := nodeKey.UnmarshalText([]byte(NodePublicKeyEnsurePrefix(machine.NodeKey)))
if err != nil { if err != nil {
log.Trace(). log.Trace().
Caller(). Caller().
@ -450,14 +450,18 @@ func (machine Machine) toNode(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machine.MachineKey)) err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err) return nil, fmt.Errorf("failed to parse machine public key: %w", err)
} }
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
if machine.DiscoKey != "" { if machine.DiscoKey != "" {
err := discoKey.UnmarshalText([]byte(discoPublicHexPrefix + machine.DiscoKey)) err := discoKey.UnmarshalText(
[]byte(DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse disco public key: %w", err) return nil, fmt.Errorf("failed to parse disco public key: %w", err)
} }
@ -634,7 +638,7 @@ func (h *Headscale) RegisterMachine(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr)) err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -192,7 +192,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machineKeyStr, machineKeyOK := machineKeyIf.(string) machineKeyStr, machineKeyOK := machineKeyIf.(string)
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(machineKeyStr)) err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse machine public key")

View file

@ -38,7 +38,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
machineKeyStr := ctx.Param("id") machineKeyStr := ctx.Param("id")
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(machineKeyStr)) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").

View file

@ -60,6 +60,30 @@ func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix) return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
} }
func MachinePublicKeyEnsurePrefix(machineKey string) string {
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
return machinePublicHexPrefix + machineKey
}
return machineKey
}
func NodePublicKeyEnsurePrefix(nodeKey string) string {
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
return nodePublicHexPrefix + nodeKey
}
return nodeKey
}
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
return discoPublicHexPrefix + discoKey
}
return discoKey
}
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string type Error string