mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-10 17:12:33 +08:00
Clean up logging and error handling in oidc
We should never expose errors via web, it gives attackers a lot of info (Insert OWASP guide). Also handle error that didnt separate not found gorm issue and other errors.
This commit is contained in:
parent
fac33e46e1
commit
fcd4d94927
1 changed files with 48 additions and 14 deletions
62
oidc.go
62
oidc.go
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
@ -15,6 +16,7 @@ import (
|
|||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error {
|
|||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
log.Error().
|
||||
Err(err).
|
||||
Caller().
|
||||
Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error {
|
|||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey.
|
||||
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||
mKeyStr := ctx.Param("mkey")
|
||||
if mKeyStr == "" {
|
||||
machineKeyStr := ctx.Param("mkey")
|
||||
if machineKeyStr == "" {
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
|
@ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
|||
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
if _, err := rand.Read(randomBlob); err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
log.Error().
|
||||
Caller().
|
||||
Msg("could not read 16 bytes from rand")
|
||||
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
|
||||
return
|
||||
|
@ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
|||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the machine key into the state cache, so it can be retrieved later
|
||||
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
|
||||
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration)
|
||||
|
||||
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
@ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
||||
log.Error().
|
||||
Err(err).
|
||||
Caller().
|
||||
Msg("failed to verify id token")
|
||||
ctx.String(http.StatusBadRequest, "Failed to verify id token")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
// Extract custom claims
|
||||
var claims IDTokenClaims
|
||||
if err = idToken.Claims(&claims); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Caller().
|
||||
Msg("Failed to decode id token claims")
|
||||
ctx.String(
|
||||
http.StatusBadRequest,
|
||||
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
||||
fmt.Sprintf("Failed to decode id token claims"),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// retrieve machinekey from state cache
|
||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
if !machineKeyFound {
|
||||
log.Error().
|
||||
Msg("requested machine state key expired before authorisation completed")
|
||||
ctx.String(http.StatusBadRequest, "state has expired")
|
||||
|
||||
return
|
||||
}
|
||||
mKeyStr, mKeyOK := mKeyIf.(string)
|
||||
machineKey, machineKeyOK := machineKeyIf.(string)
|
||||
|
||||
if !mKeyOK {
|
||||
if !machineKeyOK {
|
||||
log.Error().Msg("could not get machine key from cache")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
|
@ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
}
|
||||
|
||||
// retrieve machine information
|
||||
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
ctx.String(
|
||||
|
@ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
namespace, err = h.CreateNamespace(namespaceName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("could not create new namespace '%s'", claims.Email)
|
||||
Err(err).
|
||||
Caller().
|
||||
Msgf("could not create new namespace '%s'", namespaceName)
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not create new namespace",
|
||||
|
@ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("namespace", namespaceName).
|
||||
Msg("could not find or create namespace")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not find or create namespace",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("could not get an IP from the pool")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get an IP from the pool",
|
||||
|
@ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
}
|
||||
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", machine.Name).
|
||||
|
|
Loading…
Reference in a new issue