Compare commits

...

12 commits

Author SHA1 Message Date
Kristoffer Dalby 038aca1fe2
Merge 7c6f4e6f8a into f368ed01ed 2024-09-05 19:30:14 +02:00
Kristoffer Dalby 7c6f4e6f8a
add pr
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-05 14:19:58 +02:00
Kristoffer Dalby 2a3e21fee4
draft changelog
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-05 14:17:00 +02:00
Kristoffer Dalby d0adeb94d3
remove unused import
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 10:29:08 +02:00
Kristoffer Dalby e78aaa85ac
Update hscontrol/types/users.go
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2024-09-04 09:35:53 +02:00
Kristoffer Dalby 25a206d1c5
update nix hash
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:35:51 +02:00
Kristoffer Dalby b48b997733
start replacing User.Name with User.Username()
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:29 +02:00
Kristoffer Dalby fbe7faad1e
remove usernames in magic dns, normalisation of emails
this commit removes the option to have usernames as part of MagicDNS
domains and headscale will now align with Tailscale, where there is a
root domain, and the machine name.

In addition, the various normalisation functions for dns names has been
made lighter not caring about username and special character that wont
occur.

Email are no longer normalised as part of the policy processing.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:29 +02:00
Kristoffer Dalby b0b65042ec
expand user, add claims to user
This commit expands the user table with additional fields that
can be retrieved from OIDC providers (and other places) and
uses this data in various tailscale response objects if it is
available.

This is the beginning of implementing
https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit
trying to make OIDC more coherant and maintainable in addition
to giving the user a better experience and integration with a
provider.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:29 +02:00
Kristoffer Dalby acd1ed2b0c
remove unused state arg
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:29 +02:00
Kristoffer Dalby 4a8cbb84d5
remove unused machinekey arg
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:28 +02:00
Kristoffer Dalby 2d14d43f04
implement auth as provider interface, dry oidc
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-04 09:32:26 +02:00
27 changed files with 447 additions and 737 deletions

View file

@ -1,5 +1,18 @@
# CHANGELOG
## Next
### BREAKING
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
- Having usernames in magic DNS is no longer possible.
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
- User has been extended to store username, display name, profile picture url and email.
- These fields are forwarded to the client, and shows up nicely in the user switcher.
- These fields can be made available via the API/CLI for non-OIDC users in the future.
## 0.23.0 (2023-XX-XX)
This release is mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.

View file

@ -32,7 +32,7 @@
# When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files.
vendorHash = "sha256-+8dOxPG/Q+wuHgRwwWqdphHOuop0W9dVyClyQuh7aRc=";
vendorHash = "sha256-jDTOlRzVorbVT8fXwRedU/O2nf15cKABFrdxoTA+1wk=";
subPackages = ["cmd/headscale"];

View file

@ -18,7 +18,6 @@ import (
"syscall"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/davecgh/go-spew/spew"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -41,7 +40,6 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -95,11 +93,10 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
registrationCache *cache.Cache
authProvider AuthProvider
pollNetMapStreamWG sync.WaitGroup
}
@ -154,16 +151,31 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
})
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
err = app.initOIDC()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
cfg.ServerURL,
&cfg.OIDC,
app.db,
app.registrationCache,
app.nodeNotifier,
app.ipAlloc,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
return nil, err
} else {
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
}
} else {
authProvider = oidcProvider
}
}
app.authProvider = authProvider
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
// TODO(kradalby): revisit why this takes a list.
@ -429,10 +441,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallback).Methods(http.MethodGet)
}
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)

View file

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
@ -19,6 +18,11 @@ import (
"tailscale.com/types/ptr"
)
type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
}
func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
@ -175,7 +179,7 @@ func (h *Headscale) handleRegister(
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !regReq.Expiry.IsZero() &&
regReq.Expiry.UTC().Before(now) {
h.handleNodeLogOut(writer, *node, machineKey)
h.handleNodeLogOut(writer, *node)
return
}
@ -183,7 +187,7 @@ func (h *Headscale) handleRegister(
// If node is not expired, and it is register, we have a already accepted this node,
// let it proceed with a valid registration
if !node.IsExpired() {
h.handleNodeWithValidRegistration(writer, *node, machineKey)
h.handleNodeWithValidRegistration(writer, *node)
return
}
@ -196,7 +200,6 @@ func (h *Headscale) handleRegister(
writer,
regReq,
*node,
machineKey,
)
return
@ -209,7 +212,6 @@ func (h *Headscale) handleRegister(
writer,
regReq,
*node,
machineKey,
)
return
@ -410,7 +412,7 @@ func (h *Headscale) handleAuthKey(
}
}
h.db.Write(func(tx *gorm.DB) error {
err = h.db.Write(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak)
})
if err != nil {
@ -471,17 +473,7 @@ func (h *Headscale) handleNewNode(
// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
}
resp.AuthURL = h.authProvider.AuthURL(machineKey)
respBody, err := json.Marshal(resp)
if err != nil {
@ -504,7 +496,6 @@ func (h *Headscale) handleNewNode(
func (h *Headscale) handleNodeLogOut(
writer http.ResponseWriter,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -587,7 +578,6 @@ func (h *Headscale) handleNodeLogOut(
func (h *Headscale) handleNodeWithValidRegistration(
writer http.ResponseWriter,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -633,7 +623,6 @@ func (h *Headscale) handleNodeKeyRefresh(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -709,15 +698,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
}
resp.AuthURL = h.authProvider.AuthURL(machineKey)
respBody, err := json.Marshal(resp)
if err != nil {

View file

@ -256,9 +256,6 @@ func NewHeadscaleDatabase(
for item, node := range nodes {
if node.GivenName == "" {
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
node.Hostname,
)
if err != nil {
log.Error().
Caller().
@ -268,7 +265,7 @@ func NewHeadscaleDatabase(
}
err = tx.Model(nodes[item]).Updates(types.Node{
GivenName: normalizedHostname,
GivenName: node.Hostname,
}).Error
if err != nil {
log.Error().
@ -413,6 +410,18 @@ func NewHeadscaleDatabase(
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202407191627",
Migrate: func(tx *gorm.DB) error {
err := tx.AutoMigrate(&types.User{})
if err != nil {
return err
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View file

@ -337,7 +337,7 @@ func RegisterNodeFromAuthCallback(
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
@ -390,7 +390,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Str("user", node.User.Username()).
Msg("Registering node")
// If the node exists and it already has IP(s), we just save it
@ -406,7 +406,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Str("user", node.User.Username()).
Msg("Node authorized again")
return &node, nil
@ -617,18 +617,15 @@ func enableRoutes(tx *gorm.DB,
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
suppliedName,
)
if err != nil {
return "", err
if len(suppliedName) > util.LabelHostnameLength {
return "", types.ErrHostnameTooLong
}
if randomSuffix {
// Trim if a hostname will be longer than 63 chars after adding the hash.
trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
if len(normalizedHostname) > trimmedHostnameLength {
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
if len(suppliedName) > trimmedHostnameLength {
suppliedName = suppliedName[:trimmedHostnameLength]
}
suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
@ -636,10 +633,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return "", err
}
normalizedHostname += "-" + suffix
suppliedName += "-" + suffix
}
return normalizedHostname, nil
return suppliedName, nil
}
func (hsdb *HSDatabase) GenerateGivenName(

View file

@ -22,6 +22,7 @@ var (
)
func (hsdb *HSDatabase) CreatePreAuthKey(
// TODO(kradalby): Should be ID, not name
userName string,
reusable bool,
ephemeral bool,
@ -36,13 +37,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
// TODO(kradalby): Should be ID, not name
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, err
}
@ -104,7 +106,7 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, err
}

View file

@ -644,7 +644,7 @@ func EnableAutoApprovedRoutes(
Msg("looking up route for autoapproving")
for _, approvedAlias := range routeApprovers {
if approvedAlias == node.User.Name {
if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff

View file

@ -49,7 +49,7 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it.
func DestroyUser(tx *gorm.DB, name string) error {
user, err := GetUser(tx, name)
user, err := GetUserByName(tx, name)
if err != nil {
return ErrUserNotFound
}
@ -90,7 +90,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
// not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, oldName, newName string) error {
var err error
oldUser, err := GetUser(tx, oldName)
oldUser, err := GetUserByName(tx, oldName)
if err != nil {
return err
}
@ -98,7 +98,7 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
if err != nil {
return err
}
_, err = GetUser(tx, newName)
_, err = GetUserByName(tx, newName)
if err == nil {
return ErrUserExists
}
@ -115,13 +115,13 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
return nil
}
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUser(rx, name)
return GetUserByName(rx, name)
})
}
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
func GetUserByName(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error,
@ -133,6 +133,24 @@ func GetUser(tx *gorm.DB, name string) (*types.User, error) {
return &user, nil
}
func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByOIDCIdentifier(rx, id)
})
}
func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "provider_identifier = ?", id); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrUserNotFound
}
return &user, nil
}
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx)
@ -155,7 +173,7 @@ func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
if err != nil {
return nil, err
}
user, err := GetUser(tx, name)
user, err := GetUserByName(tx, name)
if err != nil {
return nil, err
}
@ -180,7 +198,7 @@ func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
if err != nil {
return err
}
user, err := GetUser(tx, username)
user, err := GetUserByName(tx, username)
if err != nil {
return err
}

View file

@ -20,7 +20,7 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
_, err = db.GetUserByName("test")
c.Assert(err, check.NotNil)
}
@ -73,10 +73,10 @@ func (s *Suite) TestRenameUser(c *check.C) {
err = db.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
_, err = db.GetUserByName("test")
c.Assert(err, check.Equals, ErrUserNotFound)
_, err = db.GetUser("test-renamed")
_, err = db.GetUserByName("test-renamed")
c.Assert(err, check.IsNil)
err = db.RenameUser("test-does-not-exit", "test")

View file

@ -41,7 +41,7 @@ func (api headscaleV1APIServer) GetUser(
ctx context.Context,
request *v1.GetUserRequest,
) (*v1.GetUserResponse, error) {
user, err := api.h.db.GetUser(request.GetName())
user, err := api.h.db.GetUserByName(request.GetName())
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err
}
user, err := api.h.db.GetUser(request.GetNewName())
user, err := api.h.db.GetUserByName(request.GetNewName())
if err != nil {
return nil, err
}
@ -774,7 +774,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
ctx context.Context,
request *v1.DebugCreateNodeRequest,
) (*v1.DebugCreateNodeResponse, error) {
user, err := api.h.db.GetUser(request.GetUser())
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}

View file

@ -8,6 +8,7 @@ import (
"html/template"
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/mux"
@ -167,12 +168,29 @@ var registerWebAPITemplate = template.Must(
</html>
`))
type AuthProviderWeb struct {
serverURL string
}
func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
return &AuthProviderWeb{
serverURL: serverURL,
}
}
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register/:nkey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.
func (h *Headscale) RegisterWebAPI(
func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -187,7 +205,7 @@ func (h *Headscale) RegisterWebAPI(
[]byte(machineKeyStr),
)
if err != nil {
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
log.Warn().Err(err).Msg("Failed to parse incoming machinekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)

View file

@ -15,7 +15,6 @@ import (
"sync/atomic"
"time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
@ -95,10 +94,10 @@ func generateUserProfiles(
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[string]types.User)
userMap[node.User.Name] = node.User
userMap := make(map[uint]types.User)
userMap[node.User.ID] = node.User
for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there
userMap[peer.User.ID] = peer.User // not worth checking if already is there
}
var profiles []tailcfg.UserProfile
@ -122,32 +121,6 @@ func generateDNSConfig(
dnsConfig := cfg.DNSConfig.Clone()
// if MagicDNS is enabled
if dnsConfig.Proxied {
if cfg.DNSUserNameInMagicDNS {
// Only inject the Search Domain of the current user
// shared nodes should use their full FQDN
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf(
"%s.%s",
node.User.Name,
baseDomain,
),
)
userSet := mapset.NewSet[types.User]()
userSet.Add(node.User)
for _, p := range peers {
userSet.Add(p.User)
}
for _, user := range userSet.ToSlice() {
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
dnsConfig.Routes[dnsRoute] = nil
}
}
}
addNextDNSMetadata(dnsConfig.Resolvers, node)
return dnsConfig

View file

@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
@ -28,6 +29,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Hostname: hostname,
UserID: userid,
User: types.User{
Model: gorm.Model{
ID: userid,
},
Name: username,
},
}
@ -72,14 +76,9 @@ func TestDNSConfigMapResponse(t *testing.T) {
{
magicDNS: true,
want: &tailcfg.DNSConfig{
Routes: map[string][]*dnstype.Resolver{
"shared1.foobar.headscale.net": {},
"shared2.foobar.headscale.net": {},
"shared3.foobar.headscale.net": {},
},
Routes: map[string][]*dnstype.Resolver{},
Domains: []string{
"foobar.headscale.net",
"shared1.foobar.headscale.net",
},
Proxied: true,
},
@ -127,8 +126,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
got := generateDNSConfig(
&types.Config{
DNSConfig: &dnsConfigOrig,
DNSUserNameInMagicDNS: true,
DNSConfig: &dnsConfigOrig,
},
baseDomain,
nodeInShared1,

View file

@ -76,7 +76,7 @@ func tailNode(
keyExpiry = time.Time{}
}
hostname, err := node.GetFQDN(cfg, cfg.BaseDomain)
hostname, err := node.GetFQDN(cfg.BaseDomain)
if err != nil {
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
}

View file

@ -17,8 +17,10 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
@ -45,49 +47,77 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
registrationCache *cache.Cache
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
}
func (h *Headscale) initOIDC() error {
func NewAuthProviderOIDC(
ctx context.Context,
serverURL string,
cfg *types.OIDCConfig,
db *db.HSDatabase,
registrationCache *cache.Cache,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
return fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: h.cfg.OIDC.Scope,
}
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
return nil
oauth2Config := &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
}
return &AuthProviderOIDC{
serverURL: serverURL,
cfg: cfg,
db: db,
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if h.cfg.OIDC.UseExpiryFromToken {
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
}
func (a *AuthProviderOIDC) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if a.cfg.UseExpiryFromToken {
return idTokenExpiration
}
return time.Now().Add(h.cfg.OIDC.Expiry)
return time.Now().Add(a.cfg.Expiry)
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
// Listens in /register/:mKey.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -108,46 +138,33 @@ func (h *Headscale) RegisterOIDC(
[]byte(machineKeyStr),
)
if err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
util.LogErr(err, "could not read 16 bytes from rand")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(
a.registrationCache.Set(
stateStr,
machineKey,
registerCacheExpiration,
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
for k, v := range h.cfg.OIDC.ExtraParams {
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound)
@ -170,79 +187,78 @@ var oidcCallbackTemplate = template.Must(
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into node HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(
func (a *AuthProviderOIDC) OIDCCallback(
writer http.ResponseWriter,
req *http.Request,
) {
code, state, err := validateOIDCCallbackParams(writer, req)
code, state, err := validateOIDCCallbackParams(req)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
rawIDToken, err := a.getIDTokenForOIDCCallback(req.Context(), code)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
idToken, err := a.verifyIDTokenForOIDCCallback(req.Context(), rawIDToken)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
idTokenExpiry := a.determineTokenExpiration(idToken.Expiry)
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
claims, err := extractIDTokenClaims(writer, idToken)
if err != nil {
var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
return
}
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
machineKey, nodeExists, err := a.validateNodeForOIDCCallback(
writer,
state,
claims,
&claims,
idTokenExpiry,
)
if err != nil || nodeExists {
return
}
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
// register the node if it's new
log.Debug().Msg("Registering new node after successful callback")
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
if err := a.registerNodeForOIDCCallback(user, machineKey, idTokenExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
content, err := renderOIDCCallbackTemplate(writer, claims)
content, err := renderOIDCCallbackTemplate(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
@ -254,127 +270,57 @@ func (h *Headscale) OIDCCallback(
}
func validateOIDCCallbackParams(
writer http.ResponseWriter,
req *http.Request,
) (string, string, error) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
if code == "" || state == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", "", errEmptyOIDCCallbackParams
}
return code, state, nil
}
func (h *Headscale) getIDTokenForOIDCCallback(
func (a *AuthProviderOIDC) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
code string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
if err != nil {
util.LogErr(err, "Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
return "", fmt.Errorf("could not exchange code for token: %w", err)
}
log.Trace().
Caller().
Str("code", code).
Str("state", state).
Msg("Got oidc callback")
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", errNoOIDCIDToken
}
return rawIDToken, nil
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
func (a *AuthProviderOIDC) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
util.LogErr(err, "failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
return idToken, nil
}
func extractIDTokenClaims(
writer http.ResponseWriter,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
util.LogErr(err, "Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
}
return &claims, nil
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
writer http.ResponseWriter,
allowedDomains []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
log.Trace().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedDomains
}
}
@ -387,9 +333,8 @@ func validateOIDCAllowedDomains(
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
writer http.ResponseWriter,
allowedGroups []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
@ -398,14 +343,6 @@ func validateOIDCAllowedGroups(
}
}
log.Trace().Msg("authenticated principal not in any allowed groups")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedGroups
}
@ -415,20 +352,12 @@ func validateOIDCAllowedGroups(
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
// that the authenticated principal is part of that list.
func validateOIDCAllowedUsers(
writer http.ResponseWriter,
allowedUsers []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
log.Trace().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedUsers
}
@ -439,40 +368,21 @@ func validateOIDCAllowedUsers(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
func (h *Headscale) validateNodeForOIDCCallback(
func (a *AuthProviderOIDC) validateNodeForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
expiry time.Time,
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
machineKeyIf, machineKeyFound := a.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCInvalidNodeState
}
@ -480,7 +390,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByMachineKey(machineKey)
node, _ := a.db.GetNodeByMachineKey(machineKey)
if node != nil {
log.Trace().
@ -488,20 +398,13 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("node", node.Hostname).
Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node.ID, expiry)
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "Failed to refresh node")
http.Error(
writer,
"Failed to refresh node",
http.StatusInternalServerError,
)
return nil, true, err
}
log.Debug().
Str("node", node.Hostname).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Time("expiresAt", expiry).
Msg("successfully refreshed node")
var content bytes.Buffer
@ -509,13 +412,6 @@ func (h *Headscale) validateNodeForOIDCCallback(
User: claims.Email,
Verb: "Reauthenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, true, fmt.Errorf("rendering OIDC callback template: %w", err)
}
@ -527,7 +423,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
h.nodeNotifier.NotifyByNodeID(
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
@ -537,7 +433,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil, true, nil
}
@ -545,79 +441,56 @@ func (h *Headscale) validateNodeForOIDCCallback(
return &machineKey, false, nil
}
func getUserName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
userName, err := util.NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
)
if err != nil {
util.LogErr(err, "couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return userName, nil
}
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter,
userName string,
) (*types.User, error) {
user, err := h.db.GetUser(userName)
if errors.Is(err, db.ErrUserNotFound) {
user, err = h.db.CreateUser(userName)
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, fmt.Errorf("creating new user: %w", err)
}
} else if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
// This check is for legacy, if the user cannot be found by the OIDC identifier
// look it up by username. This should only be needed once.
if user == nil {
user, err = a.db.GetUserByName(claims.Username)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return nil, fmt.Errorf("find or create user: %w", err)
// if the user is still not found, create a new empty user.
if user == nil {
user = &types.User{}
}
}
user.FromClaim(claims)
err = a.db.DB.Save(user).Error
if err != nil {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return user, nil
}
func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter,
func (a *AuthProviderOIDC) registerNodeForOIDCCallback(
user *types.User,
machineKey *key.MachinePublic,
expiry time.Time,
) error {
ipv4, ipv6, err := h.ipAlloc.Next()
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
}
if err := h.db.Write(func(tx *gorm.DB) error {
if err := a.db.Write(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache,
a.registrationCache,
*machineKey,
// TODO(kradalby): Should be ID, not name
user.Name,
&expiry,
util.RegisterMethodOIDC,
@ -628,36 +501,20 @@ func (h *Headscale) registerNodeForOIDCCallback(
return nil
}); err != nil {
util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register node"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return err
return fmt.Errorf("could not register node: %w", err)
}
return nil
}
func renderOIDCCallbackTemplate(
writer http.ResponseWriter,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
Verb: "Authenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}

View file

@ -737,15 +737,7 @@ func (pol *ACLPolicy) expandUsersFromGroup(
ErrInvalidGroup,
)
}
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
if err != nil {
return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w",
group,
ErrInvalidGroup,
)
}
users = append(users, grp)
users = append(users, group)
}
return users, nil
@ -934,7 +926,7 @@ func (pol *ACLPolicy) TagsOfNode(
}
var found bool
for _, owner := range owners {
if node.User.Name == owner {
if node.User.Username() == owner {
found = true
}
}
@ -958,7 +950,7 @@ func (pol *ACLPolicy) TagsOfNode(
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
var out types.Nodes
for _, node := range nodes {
if node.User.Name == user {
if node.User.Username() == user {
out = append(out, node)
}
}

View file

@ -341,7 +341,7 @@ func TestParsing(t *testing.T) {
],
},
],
}
}
`,
want: []tailcfg.FilterRule{
{
@ -633,25 +633,6 @@ func Test_expandGroup(t *testing.T) {
want: []string{},
wantErr: true,
},
{
name: "Expand emails in group strip domains",
field: field{
pol: ACLPolicy{
Groups: Groups{
"group:admin": []string{
"joe.bar@gmail.com",
"john.doe@yahoo.fr",
},
},
},
},
args: args{
group: "group:admin",
stripEmail: true,
},
want: []string{"joe.bar", "john.doe"},
wantErr: false,
},
{
name: "Expand emails in group",
field: field{
@ -667,7 +648,7 @@ func Test_expandGroup(t *testing.T) {
args: args{
group: "group:admin",
},
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
want: []string{"joe.bar@gmail.com", "john.doe@yahoo.fr"},
wantErr: false,
},
}

View file

@ -46,9 +46,7 @@ func (s *Suite) ResetDB(c *check.C) {
Path: tmpDir + "/headscale_test.db",
},
},
OIDC: types.OIDCConfig{
StripEmaildomain: false,
},
OIDC: types.OIDCConfig{},
}
app, err = NewHeadscale(&cfg)

View file

@ -71,8 +71,7 @@ type Config struct {
ACMEURL string
ACMEEmail string
DNSConfig *tailcfg.DNSConfig
DNSUserNameInMagicDNS bool
DNSConfig *tailcfg.DNSConfig
UnixSocket string
UnixSocketPermission fs.FileMode
@ -90,12 +89,11 @@ type Config struct {
}
type DNSConfig struct {
MagicDNS bool `mapstructure:"magic_dns"`
BaseDomain string `mapstructure:"base_domain"`
Nameservers Nameservers
SearchDomains []string `mapstructure:"search_domains"`
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
UserNameInMagicDNS bool `mapstructure:"use_username_in_magic_dns"`
MagicDNS bool `mapstructure:"magic_dns"`
BaseDomain string `mapstructure:"base_domain"`
Nameservers Nameservers
SearchDomains []string `mapstructure:"search_domains"`
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
}
type Nameservers struct {
@ -164,7 +162,6 @@ type OIDCConfig struct {
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
StripEmaildomain bool
Expiry time.Duration
UseExpiryFromToken bool
}
@ -268,7 +265,6 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("database.sqlite.write_ahead_log", true)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
@ -315,8 +311,22 @@ func LoadConfig(path string, isFile bool) error {
depr.warn("dns_config.use_username_in_magic_dns")
depr.warn("dns.use_username_in_magic_dns")
depr.fatal("oidc.strip_email_domain")
depr.fatal("dns.use_username_in_musername_in_magic_dns")
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
depr.Log()
for _, removed := range []string{
"oidc.strip_email_domain",
"dns_config.use_username_in_musername_in_magic_dns",
} {
if viper.IsSet(removed) {
log.Fatal().
Msgf("Fatal config error: %s has been removed. Please remove it from your config file", removed)
}
}
// Collect any validation errors and return them all at once
var errorText string
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
@ -566,12 +576,9 @@ func DNS() (DNSConfig, error) {
if err != nil {
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
}
dns.ExtraRecords = extraRecords
}
dns.UserNameInMagicDNS = viper.GetBool("dns.use_username_in_magic_dns")
return dns, nil
}
@ -760,7 +767,12 @@ func GetHeadscaleConfig() (*Config, error) {
case string(IPAllocationStrategyRandom):
alloc = IPAllocationStrategyRandom
default:
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
return nil, fmt.Errorf(
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
allocStr,
IPAllocationStrategySequential,
IPAllocationStrategyRandom,
)
}
dnsConfig, err := DNS()
@ -794,10 +806,11 @@ func GetHeadscaleConfig() (*Config, error) {
// - DERP run on their own domains
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
//
// TODO(kradalby): remove dnsConfig.UserNameInMagicDNS check when removed.
if !dnsConfig.UserNameInMagicDNS && dnsConfig.BaseDomain != "" && strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New("server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.")
if dnsConfig.BaseDomain != "" &&
strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New(
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
)
}
return &Config{
@ -827,8 +840,7 @@ func GetHeadscaleConfig() (*Config, error) {
TLS: GetTLSConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
DNSConfig: DNSToTailcfgDNS(dnsConfig),
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
@ -840,15 +852,14 @@ func GetHeadscaleConfig() (*Config, error) {
OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available",
),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
Expiry: func() time.Duration {
// if set to 0, we assume no expiry
if value := viper.GetString("oidc.expiry"); value == "0" {
@ -883,9 +894,11 @@ func GetHeadscaleConfig() (*Config, error) {
// TODO(kradalby): Document these settings when more stable
Tuning: Tuning{
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
NodeMapSessionBufferedChanSize: viper.GetInt(
"tuning.node_mapsession_buffered_chan_size",
),
},
}, nil
}
@ -905,14 +918,26 @@ func (d *deprecator) warnWithAlias(newKey, oldKey string) {
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
viper.RegisterAlias(newKey, oldKey)
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.",
oldKey,
newKey,
oldKey,
),
)
}
}
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
func (d *deprecator) fatal(newKey, oldKey string) {
func (d *deprecator) fatal(oldKey string) {
if viper.IsSet(oldKey) {
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.fatals.Add(
fmt.Sprintf(
"The %q configuration key has been removed. Please see the changelog for more details.",
oldKey,
),
)
}
}
@ -920,7 +945,14 @@ func (d *deprecator) fatal(newKey, oldKey string) {
// If the new key is set, a warning is emitted instead.
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.fatals.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
oldKey,
newKey,
oldKey,
),
)
} else if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
}
@ -929,14 +961,26 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
// warn deprecates and adds an option to log a warning if the oldKey is set.
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
oldKey,
newKey,
oldKey,
),
)
}
}
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
func (d *deprecator) warn(oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated and has been removed. Please see the changelog for more details.",
oldKey,
),
)
}
}

View file

@ -40,8 +40,7 @@ func TestReadConfig(t *testing.T) {
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
{
@ -97,8 +96,7 @@ func TestReadConfig(t *testing.T) {
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
{
@ -232,11 +230,10 @@ func TestReadConfigFromEnv(t *testing.T) {
{
name: "unmarshal-dns-full-config",
configEnv: map[string]string{
"HEADSCALE_DNS_MAGIC_DNS": "true",
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
"HEADSCALE_DNS_USE_USERNAME_IN_MAGIC_DNS": "true",
"HEADSCALE_DNS_MAGIC_DNS": "true",
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
// TODO(kradalby): Figure out how to pass these as env vars
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
@ -264,8 +261,7 @@ func TestReadConfigFromEnv(t *testing.T) {
ExtraRecords: []tailcfg.DNSRecord{
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
}

View file

@ -393,7 +393,7 @@ func (node *Node) Proto() *v1.Node {
return nodeProto
}
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
func (node *Node) GetFQDN(baseDomain string) (string, error) {
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
@ -408,19 +408,6 @@ func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
)
}
if cfg.DNSUserNameInMagicDNS {
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
node.User.Name,
baseDomain,
)
}
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"failed to create valid FQDN (%s): %w",

View file

@ -127,76 +127,10 @@ func TestNodeFQDN(t *testing.T) {
tests := []struct {
name string
node Node
cfg Config
domain string
want string
wantErr string
}{
{
name: "all-set-with-username",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
want: "test.user.example.com",
},
{
name: "no-given-name-with-username",
node: Node{
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node has no given name",
},
{
name: "no-user-name-with-username",
node: Node{
GivenName: "test",
User: User{},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node user has no name",
},
{
name: "no-magic-dns-with-username",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: false,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
want: "test.user.example.com",
},
{
name: "no-dnsconfig-with-username",
node: Node{
@ -216,12 +150,6 @@ func TestNodeFQDN(t *testing.T) {
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
},
@ -232,46 +160,16 @@ func TestNodeFQDN(t *testing.T) {
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node has no given name",
},
{
name: "no-user-name",
name: "too-long-username",
node: Node{
GivenName: "test",
User: User{},
GivenName: "useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111",
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
},
{
name: "no-magic-dns",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: false,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
domain: "example.com",
wantErr: "failed to create valid FQDN (useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111.example.com): hostname too long, cannot except 255 ASCII chars",
},
{
name: "no-dnsconfig",
@ -288,7 +186,9 @@ func TestNodeFQDN(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.node.GetFQDN(&tc.cfg, tc.domain)
got, err := tc.node.GetFQDN(tc.domain)
t.Logf("GOT: %q, %q", got, tc.domain)
if (err != nil) && (err.Error() != tc.wantErr) {
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)

View file

@ -1,6 +1,7 @@
package types
import (
"cmp"
"strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -16,19 +17,57 @@ import (
// that contain our machines.
type User struct {
gorm.Model
// Username for the user, is used if email is empty
// Should not be used, please use Username().
Name string `gorm:"unique"`
// Typically a full name of the user
DisplayName string
// Email of the user
// Should not be used, please use Username().
Email string
// Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token
// and is used to lookup the user.
ProviderIdentifier string `gorm:"index"`
// Provider is the origin of the user account,
// same as RegistrationMethod, without authkey.
Provider string
ProfilePicURL string
}
// Username is the main way to get the username of a user,
// it will return the email if it exists, the name if it exists,
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
// Email and OIDCIdentifier will be set when the user has headscale
// enabled with OIDC, which means that there is a domain involved which
// should be used throughout headscale, in information returned to the
// user and the Policy engine.
func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
}
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
// it will return the Username.
func (u *User) DisplayNameOrUsername() string {
return cmp.Or(u.DisplayName, u.Username())
}
// TODO(kradalby): See if we can fill in Gravatar here
func (u *User) profilePicURL() string {
return ""
return u.ProfilePicURL
}
func (u *User) TailscaleUser() *tailcfg.User {
user := tailcfg.User{
ID: tailcfg.UserID(u.ID),
LoginName: u.Name,
DisplayName: u.Name,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
Logins: []tailcfg.LoginID{},
Created: u.CreatedAt,
@ -41,9 +80,9 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
ID: tailcfg.LoginID(u.ID),
// TODO(kradalby): this should reflect registration method.
Provider: "",
LoginName: u.Name,
DisplayName: u.Name,
Provider: u.Provider,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
}
@ -53,8 +92,8 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
return tailcfg.UserProfile{
ID: tailcfg.UserID(u.ID),
LoginName: u.Name,
DisplayName: u.Name,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
}
}
@ -66,3 +105,27 @@ func (n *User) Proto() *v1.User {
CreatedAt: timestamppb.New(n.CreatedAt),
}
}
type OIDCClaims struct {
// Sub is the user's unique identifier at the provider.
Sub string `json:"sub"`
// Name is the user's full name.
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
ProfilePictureURL string `json:"picture,omitempty"`
Username string `json:"preferred_username,omitempty"`
}
// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Sub
u.DisplayName = claims.Username
u.Email = claims.Email
u.Name = claims.Username
u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC
}

View file

@ -7,7 +7,6 @@ import (
"regexp"
"strings"
"github.com/spf13/viper"
"go4.org/netipx"
"tailscale.com/util/dnsname"
)
@ -25,38 +24,6 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidUserName = errors.New("invalid user name")
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
strip := viper.GetBool("oidc.strip_email_domain")
return NormalizeToFQDNRules(name, strip)
}
// NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > LabelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidUserName,
)
}
}
return name, nil
}
func CheckForFQDNRules(name string) error {
if len(name) > LabelHostnameLength {
return fmt.Errorf(

View file

@ -7,100 +7,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestNormalizeToFQDNRules(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "user name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "user with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeToFQDNRules() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckForFQDNRules(t *testing.T) {
type args struct {
name string

View file

@ -62,7 +62,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
}
err = scenario.CreateHeadscaleEnv(
@ -121,7 +120,6 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
}
@ -276,7 +274,6 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
),
ClientID: "superclient",
ClientSecret: "supersecret",
StripEmaildomain: true,
OnlyStartIfOIDCIsAvailable: true,
}, nil
}