Compare commits

...

28 commits

Author SHA1 Message Date
Kristoffer Dalby b86db9a733
Merge e66d149cee into 10a72e8d54 2024-09-18 09:53:58 +00:00
Kristoffer Dalby e66d149cee
add pr
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:51 +01:00
Kristoffer Dalby ba0f844d5e
draft changelog
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:49 +01:00
Kristoffer Dalby fb5d40f71b
remove unused import
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:08 +01:00
Kristoffer Dalby 4850bb0c1c
Update hscontrol/types/users.go
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2024-09-18 10:53:08 +01:00
Kristoffer Dalby dc75a4f7bc
update nix hash
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:08 +01:00
Kristoffer Dalby d2efc63ca8
start replacing User.Name with User.Username()
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:08 +01:00
Kristoffer Dalby e9b95d2278
remove usernames in magic dns, normalisation of emails
this commit removes the option to have usernames as part of MagicDNS
domains and headscale will now align with Tailscale, where there is a
root domain, and the machine name.

In addition, the various normalisation functions for dns names has been
made lighter not caring about username and special character that wont
occur.

Email are no longer normalised as part of the policy processing.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:53:07 +01:00
Kristoffer Dalby 3151c629dc
expand user, add claims to user
This commit expands the user table with additional fields that
can be retrieved from OIDC providers (and other places) and
uses this data in various tailscale response objects if it is
available.

This is the beginning of implementing
https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit
trying to make OIDC more coherant and maintainable in addition
to giving the user a better experience and integration with a
provider.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:51:19 +01:00
Kristoffer Dalby be2c00d4f8
remove unused state arg
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:51:19 +01:00
Kristoffer Dalby 8e07f09f3b
remove unused machinekey arg
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:51:19 +01:00
Kristoffer Dalby 14da7c436a
implement auth as provider interface, dry oidc
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 10:51:19 +01:00
Kristoffer Dalby 10a72e8d54
update changelog for 0.23 release (#2138)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-18 09:43:08 +01:00
Kristoffer Dalby ed78ecda12
add shutdown that asserts if headscale had panics (#2126)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-17 11:44:55 +02:00
github-actions[bot] 6cbbcd859c
flake.lock: Update (#2135) 2024-09-16 06:15:45 +00:00
nblock e9d9c0773c
Exclude irrelevant files from mkdocs rendering (#2136) 2024-09-16 06:13:45 +00:00
nblock fe68f50328
Use headscale.example.com (#2122) 2024-09-11 16:46:06 +00:00
nblock c3ef90a7f7
Update documentation for Apple (#2117)
* Rename docs/ios-client.md to docs/apple-client.md. Add instructions
  for macOS; those are copied from the /apple endpoint and slightly
  modified. Fix doc links in the README.
* Move infoboxes for /apple and /windows under the "Goal" section to the
  top. Those should be seen by users first as they contain *their*
  specific headscale URL.
* Swap order of macOS and iOS to move "Profiles" further down.
* Remove apple configuration profiles
* Remove Tailscale versions hints
* Mention /apple and /windows in the README along with their docs

See: #2096
2024-09-11 18:43:59 +02:00
Kristoffer Dalby 064c46f2a5
move logic for validating node names (#2127)
* move logic for validating node names

this commits moves the generation of "given names" of nodes
into the registration function, and adds validation of renames
to RenameNode using the same logic.

Fixes #2121

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix double arg

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-11 18:27:49 +02:00
Kristoffer Dalby 64319f79ff
make stream shutdown if self-node has been removed (#2125)
* add shutdown that asserts if headscale had panics

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add test case producing 2118 panic

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* make stream shutdown if self-node has been removed

Currently we will read the node from database, and since it is
deleted, the id might be set to nil. Keep the node around and
just shutdown, so it is cleanly removed from notifier.

Fixes #2118

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-11 12:00:32 +02:00
Kristoffer Dalby 4b02dc9565
make cli mode respect log.level (#2124)
Fixes #2119

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-11 10:43:22 +02:00
Kristoffer Dalby 7be8796d87
dont override golangci go (#2116)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-09 14:29:09 +02:00
curlwget 99f18f9cd9
chore: fix some comments (#2069) 2024-09-09 14:17:25 +02:00
github-actions[bot] c3b260a6f7
flake.lock: Update (#2111) 2024-09-09 14:16:35 +02:00
Kristoffer Dalby 60b94b0467
Fix slow shutdown (#2113)
* rearrange shutdown

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* http closed is fine

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* update changelog

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* logging while shutting

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-09-09 14:10:22 +02:00
nblock bac7ea67f4
Simplify windows setup instructions (#2114)
* Simplify /windows to the bare minimum. Also remove the
  /windows/tailscale.reg endpoint as its generated file is no longer
  valid for current Tailscale versions.
* Update and simplify the windows documentation accordingly.
* Add a "Unattended mode" section to the troubleshooting section
  explaining how to enable "Unattended mode" in the via the Tailscale
  tray icon.
* Add infobox about /windows to the docs

Tested on Windows 10, 22H2 with Tailscale 1.72.0

Replaces: #1995
See: #2096
2024-09-09 13:18:16 +02:00
nblock 5597edac1e
Remove version and update setup instructions for Android (#2112) 2024-09-09 06:57:50 +00:00
Kristoffer Dalby 8a3a0fee3c
Only load needed part of configuration (#2109) 2024-09-07 09:23:58 +02:00
70 changed files with 1214 additions and 1564 deletions

View file

@ -52,6 +52,7 @@ jobs:
- TestExpireNode
- TestNodeOnlineStatus
- TestPingAllByIPManyUpDown
- Test2118DeletingOnlineNodePanics
- TestEnablingRoutes
- TestHASubnetRouterFailover
- TestEnableDisableAutoApprovedRoute

1
.gitignore vendored
View file

@ -22,6 +22,7 @@ dist/
/headscale
config.json
config.yaml
config*.yaml
derp.yaml
*.hujson
*.key

View file

@ -1,8 +1,22 @@
# CHANGELOG
## 0.23.0 (2023-XX-XX)
## Next
This release is mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
### BREAKING
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
- Having usernames in magic DNS is no longer possible.
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
- User has been extended to store username, display name, profile picture url and email.
- These fields are forwarded to the client, and shows up nicely in the user switcher.
- These fields can be made available via the API/CLI for non-OIDC users in the future.
## 0.23.0 (2023-09-18)
This release was intended to be mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
However, as you all have noticed, it turned out to become a much larger, much longer release cycle than anticipated. It has ended up to be a release with a lot of rewrites and changes to the code base and functionality of Headscale, cleaning up a lot of technical debt and introducing a lot of improvements. This does come with some breaking changes,
**Please remember to always back up your database between versions**
@ -16,7 +30,7 @@ The [“poller”, or streaming logic](https://github.com/juanfont/headscale/blo
Headscale now supports sending “delta” updates, thanks to the new mapper and poller logic, allowing us to only inform nodes about new nodes, changed nodes and removed nodes. Previously we sent the entire state of the network every time an update was due.
While we have a pretty good [test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) for validating our changes, we have rewritten over [10000 lines of code](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...main) and bugs are expected. We need help testing this release. In addition, while we think the performance should in general be better, there might be regressions in parts of the platform, particularly where we prioritised correctness over speed.
While we have a pretty good [test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) for validating our changes, the changes came down to [284 changed files with 32,316 additions and 24,245 deletions](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...ed78ecd) and bugs are expected. We need help testing this release. In addition, while we think the performance should in general be better, there might be regressions in parts of the platform, particularly where we prioritised correctness over speed.
There are also several bugfixes that has been encountered and fixed as part of implementing these changes, particularly
after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460).
@ -72,6 +86,9 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792)
- Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764)
- Make sure integration tests cover postgres for all scenarios
- CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109)
- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109)
- Fix issue where shutting down headscale would hang [#2113](https://github.com/juanfont/headscale/pull/2113)
## 0.22.3 (2023-05-12)

View file

@ -62,15 +62,15 @@ buttons available in the repo.
## Client OS support
| OS | Supports headscale |
| ------- | --------------------------------------------------------- |
| Linux | Yes |
| OpenBSD | Yes |
| FreeBSD | Yes |
| macOS | Yes (see `/apple` on your headscale for more information) |
| Windows | Yes [docs](./docs/windows-client.md) |
| Android | Yes [docs](./docs/android-client.md) |
| iOS | Yes [docs](./docs/iOS-client.md) |
| OS | Supports headscale |
| ------- | -------------------------------------------------------------------------------------------------- |
| Linux | Yes |
| OpenBSD | Yes |
| FreeBSD | Yes |
| Windows | Yes (see [docs](./docs/windows-client.md) and `/windows` on your headscale for more information) |
| Android | Yes (see [docs](./docs/android-client.md)) |
| macOS | Yes (see [docs](./docs/apple-client.md#macos) and `/apple` on your headscale for more information) |
| iOS | Yes (see [docs](./docs/apple-client.md#ios) and `/apple` on your headscale for more information) |
## Running headscale

View file

@ -54,7 +54,7 @@ var listAPIKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -67,14 +67,10 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetApiKeys(), "", output)
return
}
tableData := pterm.TableData{
@ -102,8 +98,6 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -119,9 +113,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
log.Trace().
Msg("Preparing to create ApiKey")
request := &v1.CreateApiKeyRequest{}
durationStr, _ := cmd.Flags().GetString("expiration")
@ -133,19 +124,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -156,8 +141,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Cannot create Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
@ -178,11 +161,9 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -197,8 +178,6 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)
@ -219,11 +198,9 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -238,8 +215,6 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key deleted", output)

View file

@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{
Short: "Test the configuration.",
Long: "Run a test of the configuration and exit.",
Run: func(cmd *cobra.Command, args []string) {
_, err := getHeadscaleApp()
_, err := newHeadscaleServerWithConfig()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing")
}

View file

@ -64,11 +64,9 @@ var createNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -79,8 +77,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return
}
machineKey, err := cmd.Flags().GetString("key")
@ -90,8 +86,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return
}
var mkey key.MachinePublic
@ -102,8 +96,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
@ -113,8 +105,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
request := &v1.DebugCreateNodeRequest{
@ -131,8 +121,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
output,
)
return
}
SuccessOutput(response.GetNode(), "Node created", output)

View file

@ -116,11 +116,9 @@ var registerNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -131,8 +129,6 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node key from flag: %s", err),
output,
)
return
}
request := &v1.RegisterNodeRequest{
@ -150,8 +146,6 @@ var registerNodeCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(
@ -169,17 +163,13 @@ var listNodesCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
showTags, err := cmd.Flags().GetBool("tags")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -194,21 +184,15 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
return
}
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -218,8 +202,6 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -243,7 +225,7 @@ var expireNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -286,7 +268,7 @@ var renameNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -335,7 +317,7 @@ var deleteNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -435,7 +417,7 @@ var moveNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -508,7 +490,7 @@ be assigned to nodes.`,
return
}
if confirm {
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -681,7 +663,7 @@ var tagCmd = &cobra.Command{
Aliases: []string{"tags", "t"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()

View file

@ -1,6 +1,7 @@
package cli
import (
"fmt"
"io"
"os"
@ -30,7 +31,8 @@ var getPolicy = &cobra.Command{
Short: "Print the current ACL Policy",
Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) {
ctx, client, conn, cancel := getHeadscaleCLIClient()
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -38,13 +40,13 @@ var getPolicy = &cobra.Command{
response, err := client.GetPolicy(ctx, request)
if err != nil {
log.Fatal().Err(err).Msg("Failed to get the policy")
return
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
}
// TODO(pallabpain): Maybe print this better?
SuccessOutput("", response.GetPolicy(), "hujson")
// This does not pass output as we dont support yaml, json or json-line
// output for this command. It is HuJSON already.
SuccessOutput("", response.GetPolicy(), "")
},
}
@ -56,33 +58,28 @@ var setPolicy = &cobra.Command{
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
Aliases: []string{"put", "update"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath)
if err != nil {
log.Fatal().Err(err).Msg("Error opening the policy file")
return
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
}
defer f.Close()
policyBytes, err := io.ReadAll(f)
if err != nil {
log.Fatal().Err(err).Msg("Error reading the policy file")
return
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
}
request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
if _, err := client.SetPolicy(ctx, request); err != nil {
log.Fatal().Err(err).Msg("Failed to set ACL Policy")
return
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
}
SuccessOutput(nil, "Policy updated.", "")

View file

@ -60,11 +60,9 @@ var listPreAuthKeys = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -85,8 +83,6 @@ var listPreAuthKeys = &cobra.Command{
if output != "" {
SuccessOutput(response.GetPreAuthKeys(), "", output)
return
}
tableData := pterm.TableData{
@ -134,8 +130,6 @@ var listPreAuthKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -150,20 +144,12 @@ var createPreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
tags, _ := cmd.Flags().GetStringSlice("tags")
log.Trace().
Bool("reusable", reusable).
Bool("ephemeral", ephemeral).
Str("user", user).
Msg("Preparing to create preauthkey")
request := &v1.CreatePreAuthKeyRequest{
User: user,
Reusable: reusable,
@ -180,8 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
@ -192,7 +176,7 @@ var createPreAuthKeyCmd = &cobra.Command{
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -203,8 +187,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
@ -227,11 +209,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -247,8 +227,6 @@ var expirePreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)

View file

@ -9,6 +9,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/tcnksm/go-latest"
)
@ -49,11 +50,6 @@ func initConfig() {
}
}
cfg, err := types.GetHeadscaleConfig()
if err != nil {
log.Fatal().Err(err).Msg("Failed to read headscale configuration")
}
machineOutput := HasMachineOutputFlag()
// If the user has requested a "node" readable format,
@ -62,11 +58,13 @@ func initConfig() {
zerolog.SetGlobalLevel(zerolog.Disabled)
}
if cfg.Log.Format == types.JSONLogFormat {
log.Logger = log.Output(os.Stdout)
}
// logFormat := viper.GetString("log.format")
// if logFormat == types.JSONLogFormat {
// log.Logger = log.Output(os.Stdout)
// }
if !cfg.DisableUpdateCheck && !machineOutput {
disableUpdateCheck := viper.GetBool("disable_check_updates")
if !disableUpdateCheck && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
Version != "dev" {
githubTag := &latest.GithubTag{

View file

@ -64,11 +64,9 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -82,14 +80,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetRoutes(), "", output)
return
}
routes = response.GetRoutes()
@ -103,14 +97,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetRoutes(), "", output)
return
}
routes = response.GetRoutes()
@ -119,8 +109,6 @@ var listRoutesCmd = &cobra.Command{
tableData := routesToPtables(routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -130,8 +118,6 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -150,11 +136,9 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -167,14 +151,10 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}
@ -193,11 +173,9 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -210,14 +188,10 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}
@ -236,11 +210,9 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -253,14 +225,10 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}

View file

@ -1,6 +1,9 @@
package cli
import (
"errors"
"net/http"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)
@ -16,14 +19,14 @@ var serveCmd = &cobra.Command{
return nil
},
Run: func(cmd *cobra.Command, args []string) {
app, err := getHeadscaleApp()
app, err := newHeadscaleServerWithConfig()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing")
}
err = app.Serve()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error starting server")
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
}
},
}

View file

@ -44,7 +44,7 @@ var createUserCmd = &cobra.Command{
userName := args[0]
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -63,8 +63,6 @@ var createUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response.GetUser(), "User created", output)
@ -91,7 +89,7 @@ var destroyUserCmd = &cobra.Command{
Name: userName,
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -102,8 +100,6 @@ var destroyUserCmd = &cobra.Command{
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output,
)
return
}
confirm := false
@ -134,8 +130,6 @@ var destroyUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response, "User destroyed", output)
} else {
@ -151,7 +145,7 @@ var listUsersCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -164,14 +158,10 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetUsers(), "", output)
return
}
tableData := pterm.TableData{{"ID", "Name", "Created"}}
@ -192,8 +182,6 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -213,7 +201,7 @@ var renameUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -232,8 +220,6 @@ var renameUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response.GetUser(), "User renamed", output)

View file

@ -23,8 +23,8 @@ const (
SocketWritePermissions = 0o666
)
func getHeadscaleApp() (*hscontrol.Headscale, error) {
cfg, err := types.GetHeadscaleConfig()
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
cfg, err := types.LoadServerConfig()
if err != nil {
return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w",
@ -40,8 +40,8 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := types.GetHeadscaleConfig()
func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := types.LoadCLIConfig()
if err != nil {
log.Fatal().
Err(err).
@ -130,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
return ctx, client, conn, cancel
}
func SuccessOutput(result interface{}, override string, outputFormat string) {
func output(result interface{}, override string, outputFormat string) string {
var jsonBytes []byte
var err error
switch outputFormat {
@ -151,21 +151,26 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
}
default:
// nolint
fmt.Println(override)
return
return override
}
// nolint
fmt.Println(string(jsonBytes))
return string(jsonBytes)
}
// SuccessOutput prints the result to stdout and exits with status code 0.
func SuccessOutput(result interface{}, override string, outputFormat string) {
fmt.Println(output(result, override, outputFormat))
os.Exit(0)
}
// ErrorOutput prints an error message to stderr and exits with status code 1.
func ErrorOutput(errResult error, override string, outputFormat string) {
type errOutput struct {
Error string `json:"error"`
}
SuccessOutput(errOutput{errResult.Error()}, override, outputFormat)
fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errResult.Error()}, override, outputFormat))
os.Exit(1)
}
func HasMachineOutputFlag() bool {

View file

@ -4,7 +4,6 @@ import (
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
@ -113,60 +112,3 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false)
}
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
}
func (*Suite) TestTLSConfigValidation(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
c.Assert(
tmp,
check.Matches,
".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
writeConfig(c, tmpDir, configYaml)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)
}

View file

@ -8,12 +8,9 @@ This documentation has the goal of showing how a user can use the official Andro
Install the official Tailscale Android client from the [Google Play Store](https://play.google.com/store/apps/details?id=com.tailscale.ipn) or [F-Droid](https://f-droid.org/packages/com.tailscale.ipn/).
Ensure that the installed version is at least 1.30.0, as that is the first release to support custom URLs.
## Configuring the headscale URL
After opening the app:
- Open setting and go into account settings
- In the kebab menu icon (three dots) on the top bar on the right select “Use an alternate server”
- Enter your server URL and follow the instructions
- Open the app and select the settings menu in the upper-right corner
- Tap on `Accounts`
- In the kebab menu icon (three dots) in the upper-right corner select `Use an alternate server`
- Enter your server URL (e.g `https://headscale.example.com`) and follow the instructions

51
docs/apple-client.md Normal file
View file

@ -0,0 +1,51 @@
# Connecting an Apple client
## Goal
This documentation has the goal of showing how a user can use the official iOS and macOS [Tailscale](https://tailscale.com) clients with `headscale`.
!!! info "Instructions on your headscale instance"
An endpoint with information on how to connect your Apple device
is also available at `/apple` on your running instance.
## iOS
### Installation
Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037).
### Configuring the headscale URL
- Open Tailscale and make sure you are _not_ logged in to any account
- Open Settings on the iOS device
- Scroll down to the `third party apps` section, under `Game Center` or `TV Provider`
- Find Tailscale and select it
- If the iOS device was previously logged into Tailscale, switch the `Reset Keychain` toggle to `on`
- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) under `Alternate Coordination Server URL`
- Restart the app by closing it from the iOS app switcher, open the app and select the regular sign in option
_(non-SSO)_. It should open up to the headscale authentication page.
- Enter your credentials and log in. Headscale should now be working on your iOS device.
## macOS
### Installation
Choose one of the available [Tailscale clients for macOS](https://tailscale.com/kb/1065/macos-variants) and install it.
### Configuring the headscale URL
#### Command line
Use Tailscale's login command to connect with your headscale instance (e.g `https://headscale.example.com`):
```
tailscale login --login-server <YOUR_HEADSCALE_URL>
```
#### GUI
- ALT + Click the Tailscale icon in the menu and hover over the Debug menu
- Under `Custom Login Server`, select `Add Account...`
- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) and press `Add Account`
- Follow the login procedure in the browser

View file

@ -5,7 +5,7 @@
Register the node and make it advertise itself as an exit node:
```console
$ sudo tailscale up --login-server https://my-server.com --advertise-exit-node
$ sudo tailscale up --login-server https://headscale.example.com --advertise-exit-node
```
If the node is already registered, it can advertise exit capabilities like this:

View file

@ -1,30 +0,0 @@
# Connecting an iOS client
## Goal
This documentation has the goal of showing how a user can use the official iOS [Tailscale](https://tailscale.com) client with `headscale`.
## Installation
Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037).
Ensure that the installed version is at least 1.38.1, as that is the first release to support alternate control servers.
## Configuring the headscale URL
!!! info "Apple devices"
An endpoint with information on how to connect your Apple devices
(currently macOS only) is available at `/apple` on your running instance.
Ensure that the tailscale app is logged out before proceeding.
Go to iOS settings, scroll down past game center and tv provider to the tailscale app and select it. The headscale URL can be entered into the _"ALTERNATE COORDINATION SERVER URL"_ box.
> **Note**
>
> If the app was previously logged into tailscale, toggle on the _Reset Keychain_ switch.
Restart the app by closing it from the iOS app switcher, open the app and select the regular _Sign in_ option (non-SSO), and it should open up to the headscale authentication page.
Enter your credentials and log in. Headscale should now be working on your iOS device.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 101 KiB

View file

@ -4,39 +4,41 @@
This documentation has the goal of showing how a user can use the official Windows [Tailscale](https://tailscale.com) client with `headscale`.
## Add registry keys
!!! info "Instructions on your headscale instance"
To make the Windows client behave as expected and to run well with `headscale`, two registry keys **must** be set:
- `HKLM:\SOFTWARE\Tailscale IPN\UnattendedMode` must be set to `always` as a `string` type, to allow Tailscale to run properly in the background
- `HKLM:\SOFTWARE\Tailscale IPN\LoginURL` must be set to `<YOUR HEADSCALE URL>` as a `string` type, to ensure Tailscale contacts the correct control server.
You can set these using the Windows Registry Editor:
![windows-registry](./images/windows-registry.png)
Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"):
```
New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
```
The Tailscale Windows client has been observed to reset its configuration on logout/reboot and these two keys [resolves that issue](https://github.com/tailscale/tailscale/issues/2798).
For a guide on how to edit registry keys, [check out Computer Hope](https://www.computerhope.com/issues/ch001348.htm).
An endpoint with information on how to connect your Windows device
is also available at `/windows` on your running instance.
## Installation
Download the [Official Windows Client](https://tailscale.com/download/windows) and install it.
When the installation has finished, start Tailscale and log in (you might have to click the icon in the system tray).
## Configuring the headscale URL
The log in should open a browser Window and direct you to your `headscale` instance.
Open a Command Prompt or Powershell and use Tailscale's login command to connect with your headscale instance (e.g
`https://headscale.example.com`):
```
tailscale login --login-server <YOUR_HEADSCALE_URL>
```
Follow the instructions in the opened browser window to finish the configuration.
## Troubleshooting
### Unattended mode
By default, Tailscale's Windows client is only running when the user is logged in. If you want to keep Tailscale running
all the time, please enable "Unattended mode":
- Click on the Tailscale tray icon and select `Preferences`
- Enable `Run unattended`
- Confirm the "Unattended mode" message
See also [Keep Tailscale running when I'm not logged in to my computer](https://tailscale.com/kb/1088/run-unattended)
### Failing node registration
If you are seeing repeated messages like:
```
@ -53,8 +55,7 @@ This typically means that the registry keys above was not set appropriately.
To reset and try again, it is important to do the following:
1. Ensure the registry keys from the previous guide is correctly set.
2. Shut down the Tailscale service (or the client running in the tray)
3. Delete Tailscale Application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` and try to connect again.
4. Ensure the Windows node is deleted from headscale (to ensure fresh setup)
5. Start Tailscale on the windows machine and retry the login.
1. Shut down the Tailscale service (or the client running in the tray)
2. Delete Tailscale Application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` and try to connect again.
3. Ensure the Windows node is deleted from headscale (to ensure fresh setup)
4. Start Tailscale on the Windows machine and retry the login.

View file

@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1725099143,
"narHash": "sha256-CHgumPZaC7z+WYx72WgaLt2XF0yUVzJS60rO4GZ7ytY=",
"lastModified": 1726238386,
"narHash": "sha256-3//V84fYaGVncFImitM6lSAliRdrGayZLdxWlpcuGk0=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "5629520edecb69630a3f4d17d3d33fc96c13f6fe",
"rev": "01f064c99c792715054dc7a70e4c1626dbbec0c3",
"type": "github"
},
"original": {

View file

@ -32,7 +32,7 @@
# When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files.
vendorHash = "sha256-+8dOxPG/Q+wuHgRwwWqdphHOuop0W9dVyClyQuh7aRc=";
vendorHash = "sha256-jDTOlRzVorbVT8fXwRedU/O2nf15cKABFrdxoTA+1wk=";
subPackages = ["cmd/headscale"];
@ -57,9 +57,11 @@
subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"];
};
golangci-lint = prev.golangci-lint.override {
buildGoModule = buildGo;
};
# Upstream does not override buildGoModule properly,
# importing a specific module, so comment out for now.
# golangci-lint = prev.golangci-lint.override {
# buildGoModule = buildGo;
# };
goreleaser = prev.goreleaser.override {
buildGoModule = buildGo;

View file

@ -18,7 +18,6 @@ import (
"syscall"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/davecgh/go-spew/spew"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -41,7 +40,6 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -95,11 +93,10 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
registrationCache *cache.Cache
authProvider AuthProvider
pollNetMapStreamWG sync.WaitGroup
}
@ -154,16 +151,31 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
})
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
err = app.initOIDC()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
cfg.ServerURL,
&cfg.OIDC,
app.db,
app.registrationCache,
app.nodeNotifier,
app.ipAlloc,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
return nil, err
} else {
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
}
} else {
authProvider = oidcProvider
}
}
app.authProvider = authProvider
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
// TODO(kradalby): revisit why this takes a list.
@ -429,16 +441,15 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallback).Methods(http.MethodGet)
}
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).
Methods(http.MethodGet)
// TODO(kristoffer): move swagger into a package
router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet)
@ -772,7 +783,7 @@ func (h *Headscale) Serve() error {
})
}
default:
trace := log.Trace().Msgf
info := func(msg string) { log.Info().Msg(msg) }
log.Info().
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")
@ -780,55 +791,55 @@ func (h *Headscale) Serve() error {
expireNodeCancel()
h.ephemeralGC.Close()
trace("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers
ctx, cancel := context.WithTimeout(
context.Background(),
types.HTTPShutdownTimeout,
)
trace("shutting down debug http server")
info("shutting down debug http server")
if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
log.Error().Err(err).Msg("failed to shutdown prometheus http")
}
trace("shutting down main http server")
info("shutting down main http server")
if err := httpServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown http")
log.Error().Err(err).Msg("failed to shutdown http")
}
trace("shutting down grpc server (socket)")
info("closing node notifier")
h.nodeNotifier.Close()
info("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
info("shutting down grpc server (socket)")
grpcSocket.GracefulStop()
if grpcServer != nil {
trace("shutting down grpc server (external)")
info("shutting down grpc server (external)")
grpcServer.GracefulStop()
grpcListener.Close()
}
if tailsqlContext != nil {
trace("shutting down tailsql")
info("shutting down tailsql")
tailsqlContext.Done()
}
trace("closing node notifier")
h.nodeNotifier.Close()
// Close network listeners
trace("closing network listeners")
info("closing network listeners")
debugHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()
// Stop listening (and unlink the socket if unix type):
trace("closing socket listener")
info("closing socket listener")
socketListener.Close()
// Close db connections
trace("closing database connection")
info("closing database connection")
err = h.db.Close()
if err != nil {
log.Error().Err(err).Msg("Failed to close db")
log.Error().Err(err).Msg("failed to close db")
}
log.Info().

View file

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
@ -19,6 +18,11 @@ import (
"tailscale.com/types/ptr"
)
type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
}
func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
@ -66,7 +70,7 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey)
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
@ -105,16 +109,6 @@ func (h *Headscale) handleRegister(
logInfo("Node not found in database, creating new")
givenName, err := h.db.GenerateGivenName(
machineKey,
regReq.Hostinfo.Hostname,
)
if err != nil {
logErr(err, "Failed to generate given name for node")
return
}
// The node did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback
@ -122,7 +116,6 @@ func (h *Headscale) handleRegister(
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
@ -175,7 +168,7 @@ func (h *Headscale) handleRegister(
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !regReq.Expiry.IsZero() &&
regReq.Expiry.UTC().Before(now) {
h.handleNodeLogOut(writer, *node, machineKey)
h.handleNodeLogOut(writer, *node)
return
}
@ -183,7 +176,7 @@ func (h *Headscale) handleRegister(
// If node is not expired, and it is register, we have a already accepted this node,
// let it proceed with a valid registration
if !node.IsExpired() {
h.handleNodeWithValidRegistration(writer, *node, machineKey)
h.handleNodeWithValidRegistration(writer, *node)
return
}
@ -196,7 +189,6 @@ func (h *Headscale) handleRegister(
writer,
regReq,
*node,
machineKey,
)
return
@ -209,7 +201,6 @@ func (h *Headscale) handleRegister(
writer,
regReq,
*node,
machineKey,
)
return
@ -354,21 +345,8 @@ func (h *Headscale) handleAuthKey(
} else {
now := time.Now().UTC()
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed to generate given name for node")
return
}
nodeToRegister := types.Node{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
User: pak.User,
MachineKey: machineKey,
@ -410,7 +388,7 @@ func (h *Headscale) handleAuthKey(
}
}
h.db.Write(func(tx *gorm.DB) error {
err = h.db.Write(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak)
})
if err != nil {
@ -471,17 +449,7 @@ func (h *Headscale) handleNewNode(
// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
}
resp.AuthURL = h.authProvider.AuthURL(machineKey)
respBody, err := json.Marshal(resp)
if err != nil {
@ -504,7 +472,6 @@ func (h *Headscale) handleNewNode(
func (h *Headscale) handleNodeLogOut(
writer http.ResponseWriter,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -587,7 +554,6 @@ func (h *Headscale) handleNodeLogOut(
func (h *Headscale) handleNodeWithValidRegistration(
writer http.ResponseWriter,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -633,7 +599,6 @@ func (h *Headscale) handleNodeKeyRefresh(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
@ -709,15 +674,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String())
}
resp.AuthURL = h.authProvider.AuthURL(machineKey)
respBody, err := json.Marshal(resp)
if err != nil {

View file

@ -256,9 +256,6 @@ func NewHeadscaleDatabase(
for item, node := range nodes {
if node.GivenName == "" {
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
node.Hostname,
)
if err != nil {
log.Error().
Caller().
@ -268,7 +265,7 @@ func NewHeadscaleDatabase(
}
err = tx.Model(nodes[item]).Updates(types.Node{
GivenName: normalizedHostname,
GivenName: node.Hostname,
}).Error
if err != nil {
log.Error().
@ -413,6 +410,18 @@ func NewHeadscaleDatabase(
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
ID: "202407191627",
Migrate: func(tx *gorm.DB) error {
err := tx.AutoMigrate(&types.User{})
if err != nil {
return err
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View file

@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
})
}
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
}
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name)
@ -242,9 +228,9 @@ func SetTags(
}
// RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it.
// and renames it. If the name is not unique, it will return an error.
func RenameNode(tx *gorm.DB,
nodeID uint64, newName string,
nodeID types.NodeID, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB,
return fmt.Errorf("renaming node: %w", err)
}
uniq, err := isUnqiueName(tx, newName)
if err != nil {
return fmt.Errorf("checking if name is unique: %w", err)
}
if !uniq {
return fmt.Errorf("name is not unique: %s", newName)
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err)
}
@ -337,7 +332,7 @@ func RegisterNodeFromAuthCallback(
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, fmt.Errorf(
"failed to find user in register node from auth callback, %w",
@ -390,7 +385,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Str("user", node.User.Username()).
Msg("Registering node")
// If the node exists and it already has IP(s), we just save it
@ -406,7 +401,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Name).
Str("user", node.User.Username()).
Msg("Node authorized again")
return &node, nil
@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.IPv4 = ipv4
node.IPv6 = ipv6
if node.GivenName == "" {
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
if err != nil {
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
}
node.GivenName = givenName
}
if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
}
@ -617,18 +621,15 @@ func enableRoutes(tx *gorm.DB,
}
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
suppliedName,
)
if err != nil {
return "", err
if len(suppliedName) > util.LabelHostnameLength {
return "", types.ErrHostnameTooLong
}
if randomSuffix {
// Trim if a hostname will be longer than 63 chars after adding the hash.
trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
if len(normalizedHostname) > trimmedHostnameLength {
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
if len(suppliedName) > trimmedHostnameLength {
suppliedName = suppliedName[:trimmedHostnameLength]
}
suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
@ -636,46 +637,38 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return "", err
}
normalizedHostname += "-" + suffix
suppliedName += "-" + suffix
}
return normalizedHostname, nil
return suppliedName, nil
}
func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
return GenerateGivenName(rx, mkey, suppliedName)
})
func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
nodes := types.Nodes{}
if err := tx.
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
return false, err
}
return len(nodes) == 0, nil
}
func GenerateGivenName(
func ensureUniqueGivenName(
tx *gorm.DB,
mkey key.MachinePublic,
suppliedName string,
name string,
) (string, error) {
givenName, err := generateGivenName(suppliedName, false)
givenName, err := generateGivenName(name, false)
if err != nil {
return "", err
}
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
nodes, err := listNodesByGivenName(tx, givenName)
unique, err := isUnqiueName(tx, givenName)
if err != nil {
return "", err
}
var nodeFound *types.Node
for idx, node := range nodes {
if node.GivenName == givenName {
nodeFound = nodes[idx]
}
}
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
postfixedName, err := generateGivenName(suppliedName, true)
if !unique {
postfixedName, err := generateGivenName(name, true)
if err != nil {
return "", err
}

View file

@ -19,6 +19,7 @@ import (
"github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
}
func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
machineKey2 := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "hostname-1",
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-2", comment)
givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
comment = check.Commentf("Same user, same node, same hostname, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-1", comment)
givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
}
func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) {
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
}
func TestRenameNode(t *testing.T) {
db, err := newTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser("test")
assert.NoError(t, err)
user2, err := db.CreateUser("test2")
assert.NoError(t, err)
node := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
node2 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
err = db.DB.Save(&node).Error
assert.NoError(t, err)
err = db.DB.Save(&node2).Error
assert.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node, nil, nil)
if err != nil {
return err
}
_, err = RegisterNode(tx, node2, nil, nil)
return err
})
assert.NoError(t, err)
nodes, err := db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName)
t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName)
assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname)
assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName)
assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname)
assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname)
assert.Len(t, nodes[0].Hostname, 4)
assert.Len(t, nodes[1].Hostname, 4)
assert.Len(t, nodes[0].GivenName, 4)
assert.Len(t, nodes[1].GivenName, 13)
// Nodes can be renamed to a unique name
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "newname")
})
assert.NoError(t, err)
nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")
// Nodes can reuse name that is no longer used
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[1].ID, "test")
})
assert.NoError(t, err)
nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")
assert.Equal(t, nodes[1].GivenName, "test")
// Nodes cannot be renamed to used names
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
}

View file

@ -22,6 +22,7 @@ var (
)
func (hsdb *HSDatabase) CreatePreAuthKey(
// TODO(kradalby): Should be ID, not name
userName string,
reusable bool,
ephemeral bool,
@ -36,13 +37,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey(
tx *gorm.DB,
// TODO(kradalby): Should be ID, not name
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, err
}
@ -104,7 +106,7 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUser(tx, userName)
user, err := GetUserByName(tx, userName)
if err != nil {
return nil, err
}

View file

@ -644,7 +644,7 @@ func EnableAutoApprovedRoutes(
Msg("looking up route for autoapproving")
for _, approvedAlias := range routeApprovers {
if approvedAlias == node.User.Name {
if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff

View file

@ -49,7 +49,7 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it.
func DestroyUser(tx *gorm.DB, name string) error {
user, err := GetUser(tx, name)
user, err := GetUserByName(tx, name)
if err != nil {
return ErrUserNotFound
}
@ -90,7 +90,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
// not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, oldName, newName string) error {
var err error
oldUser, err := GetUser(tx, oldName)
oldUser, err := GetUserByName(tx, oldName)
if err != nil {
return err
}
@ -98,7 +98,7 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
if err != nil {
return err
}
_, err = GetUser(tx, newName)
_, err = GetUserByName(tx, newName)
if err == nil {
return ErrUserExists
}
@ -115,13 +115,13 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
return nil
}
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUser(rx, name)
return GetUserByName(rx, name)
})
}
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
func GetUserByName(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error,
@ -133,6 +133,24 @@ func GetUser(tx *gorm.DB, name string) (*types.User, error) {
return &user, nil
}
func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByOIDCIdentifier(rx, id)
})
}
func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "provider_identifier = ?", id); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrUserNotFound
}
return &user, nil
}
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx)
@ -155,7 +173,7 @@ func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
if err != nil {
return nil, err
}
user, err := GetUser(tx, name)
user, err := GetUserByName(tx, name)
if err != nil {
return nil, err
}
@ -180,7 +198,7 @@ func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
if err != nil {
return err
}
user, err := GetUser(tx, username)
user, err := GetUserByName(tx, username)
if err != nil {
return err
}

View file

@ -20,7 +20,7 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
err = db.DestroyUser("test")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
_, err = db.GetUserByName("test")
c.Assert(err, check.NotNil)
}
@ -73,10 +73,10 @@ func (s *Suite) TestRenameUser(c *check.C) {
err = db.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil)
_, err = db.GetUser("test")
_, err = db.GetUserByName("test")
c.Assert(err, check.Equals, ErrUserNotFound)
_, err = db.GetUser("test-renamed")
_, err = db.GetUserByName("test-renamed")
c.Assert(err, check.IsNil)
err = db.RenameUser("test-does-not-exit", "test")

View file

@ -41,7 +41,7 @@ func (api headscaleV1APIServer) GetUser(
ctx context.Context,
request *v1.GetUserRequest,
) (*v1.GetUserResponse, error) {
user, err := api.h.db.GetUser(request.GetName())
user, err := api.h.db.GetUserByName(request.GetName())
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (api headscaleV1APIServer) RenameUser(
return nil, err
}
user, err := api.h.db.GetUser(request.GetNewName())
user, err := api.h.db.GetUserByName(request.GetNewName())
if err != nil {
return nil, err
}
@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode(
tx,
request.GetNodeId(),
types.NodeID(request.GetNodeId()),
request.GetNewName(),
)
if err != nil {
@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy(
case types.PolicyModeDB:
p, err := api.h.db.GetPolicy()
if err != nil {
return nil, err
return nil, fmt.Errorf("loading ACL from database: %w", err)
}
return &v1.GetPolicyResponse{
@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy(
absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
f, err := os.Open(absPath)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
}
defer f.Close()
b, err := io.ReadAll(f)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from file: %w", err)
}
return &v1.GetPolicyResponse{Policy: string(b)}, nil
}
return nil, nil
return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
}
func (api headscaleV1APIServer) SetPolicy(
@ -774,7 +774,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
ctx context.Context,
request *v1.DebugCreateNodeRequest,
) (*v1.DebugCreateNodeResponse, error) {
user, err := api.h.db.GetUser(request.GetUser())
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err
}
givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
if err != nil {
return nil, err
}
nodeKey := key.NewNode()
newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
GivenName: givenName,
User: *user,
Expiry: &time.Time{},

View file

@ -8,6 +8,7 @@ import (
"html/template"
"net/http"
"strconv"
"strings"
"time"
"github.com/gorilla/mux"
@ -167,12 +168,29 @@ var registerWebAPITemplate = template.Must(
</html>
`))
type AuthProviderWeb struct {
serverURL string
}
func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
return &AuthProviderWeb{
serverURL: serverURL,
}
}
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register/:nkey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.
func (h *Headscale) RegisterWebAPI(
func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -187,7 +205,7 @@ func (h *Headscale) RegisterWebAPI(
[]byte(machineKeyStr),
)
if err != nil {
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
log.Warn().Err(err).Msg("Failed to parse incoming machinekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)

View file

@ -15,7 +15,6 @@ import (
"sync/atomic"
"time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
@ -95,10 +94,10 @@ func generateUserProfiles(
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[string]types.User)
userMap[node.User.Name] = node.User
userMap := make(map[uint]types.User)
userMap[node.User.ID] = node.User
for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there
userMap[peer.User.ID] = peer.User // not worth checking if already is there
}
var profiles []tailcfg.UserProfile
@ -122,32 +121,6 @@ func generateDNSConfig(
dnsConfig := cfg.DNSConfig.Clone()
// if MagicDNS is enabled
if dnsConfig.Proxied {
if cfg.DNSUserNameInMagicDNS {
// Only inject the Search Domain of the current user
// shared nodes should use their full FQDN
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf(
"%s.%s",
node.User.Name,
baseDomain,
),
)
userSet := mapset.NewSet[types.User]()
userSet.Add(node.User)
for _, p := range peers {
userSet.Add(p.User)
}
for _, user := range userSet.ToSlice() {
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
dnsConfig.Routes[dnsRoute] = nil
}
}
}
addNextDNSMetadata(dnsConfig.Resolvers, node)
return dnsConfig
@ -227,7 +200,7 @@ func (m *Mapper) FullMapResponse(
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// ReadOnlyResponse returns a MapResponse for the given node.
// ReadOnlyMapResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse(
@ -552,7 +525,7 @@ func appendPeerChanges(
}
// If there are filter rules present, see if there are any nodes that cannot
// access eachother at all and remove them from the peers.
// access each-other at all and remove them from the peers.
if len(packetFilter) > 0 {
changed = policy.FilterNodesByACL(node, changed, packetFilter)
}
@ -596,7 +569,7 @@ func appendPeerChanges(
} else {
// This is a hack to avoid sending an empty list of packet filters.
// Since tailcfg.PacketFilter has omitempty, any empty PacketFilter will
// be omitted, causing the client to consider it unchange, keeping the
// be omitted, causing the client to consider it unchanged, keeping the
// previous packet filter. Worst case, this can cause a node that previously
// has access to a node to _not_ loose access if an empty (allow none) is sent.
reduced := policy.ReduceFilterRules(node, packetFilter)

View file

@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
@ -28,6 +29,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
Hostname: hostname,
UserID: userid,
User: types.User{
Model: gorm.Model{
ID: userid,
},
Name: username,
},
}
@ -72,14 +76,9 @@ func TestDNSConfigMapResponse(t *testing.T) {
{
magicDNS: true,
want: &tailcfg.DNSConfig{
Routes: map[string][]*dnstype.Resolver{
"shared1.foobar.headscale.net": {},
"shared2.foobar.headscale.net": {},
"shared3.foobar.headscale.net": {},
},
Routes: map[string][]*dnstype.Resolver{},
Domains: []string{
"foobar.headscale.net",
"shared1.foobar.headscale.net",
},
Proxied: true,
},
@ -127,8 +126,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
got := generateDNSConfig(
&types.Config{
DNSConfig: &dnsConfigOrig,
DNSUserNameInMagicDNS: true,
DNSConfig: &dnsConfigOrig,
},
baseDomain,
nodeInShared1,

View file

@ -76,7 +76,7 @@ func tailNode(
keyExpiry = time.Time{}
}
hostname, err := node.GetFQDN(cfg, cfg.BaseDomain)
hostname, err := node.GetFQDN(cfg.BaseDomain)
if err != nil {
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
}

View file

@ -36,6 +36,7 @@ type Notifier struct {
connected *xsync.MapOf[types.NodeID, bool]
b *batcher
cfg *types.Config
closed bool
}
func NewNotifier(cfg *types.Config) *Notifier {
@ -43,6 +44,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: xsync.NewMapOf[types.NodeID, bool](),
cfg: cfg,
closed: false,
}
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
n.b = b
@ -51,9 +53,19 @@ func NewNotifier(cfg *types.Config) *Notifier {
return n
}
// Close stops the batcher inside the notifier.
// Close stops the batcher and closes all channels.
func (n *Notifier) Close() {
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
n.closed = true
n.b.close()
for _, c := range n.nodes {
close(c)
}
}
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
@ -70,6 +82,10 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
if n.closed {
return
}
// If a channel exists, it means the node has opened a new
// connection. Close the old channel and replace it.
if curr, ok := n.nodes[nodeID]; ok {
@ -96,6 +112,10 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) b
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
if n.closed {
return true
}
if len(n.nodes) == 0 {
return true
}
@ -154,6 +174,10 @@ func (n *Notifier) NotifyWithIgnore(
update types.StateUpdate,
ignoreNodeIDs ...types.NodeID,
) {
if n.closed {
return
}
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
n.b.addOrPassthrough(update)
}
@ -170,6 +194,10 @@ func (n *Notifier) NotifyByNodeID(
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
if n.closed {
return
}
if c, ok := n.nodes[nodeID]; ok {
select {
case <-ctx.Done():
@ -205,6 +233,10 @@ func (n *Notifier) sendAll(update types.StateUpdate) {
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
if n.closed {
return
}
for id, c := range n.nodes {
// Whenever an update is sent to all nodes, there is a chance that the node
// has disconnected and the goroutine that was supposed to consume the update

View file

@ -17,8 +17,10 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"gorm.io/gorm"
@ -45,49 +47,77 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
registrationCache *cache.Cache
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
}
func (h *Headscale) initOIDC() error {
func NewAuthProviderOIDC(
ctx context.Context,
serverURL string,
cfg *types.OIDCConfig,
db *db.HSDatabase,
registrationCache *cache.Cache,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
return fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: h.cfg.OIDC.Scope,
}
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
}
return nil
oauth2Config := &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
}
return &AuthProviderOIDC{
serverURL: serverURL,
cfg: cfg,
db: db,
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if h.cfg.OIDC.UseExpiryFromToken {
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
}
func (a *AuthProviderOIDC) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if a.cfg.UseExpiryFromToken {
return idTokenExpiration
}
return time.Now().Add(h.cfg.OIDC.Expiry)
return time.Now().Add(a.cfg.Expiry)
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
// Listens in /register/:mKey.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -108,46 +138,33 @@ func (h *Headscale) RegisterOIDC(
[]byte(machineKeyStr),
)
if err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
util.LogErr(err, "could not read 16 bytes from rand")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(
a.registrationCache.Set(
stateStr,
machineKey,
registerCacheExpiration,
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
for k, v := range h.cfg.OIDC.ExtraParams {
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound)
@ -170,79 +187,78 @@ var oidcCallbackTemplate = template.Must(
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into node HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(
func (a *AuthProviderOIDC) OIDCCallback(
writer http.ResponseWriter,
req *http.Request,
) {
code, state, err := validateOIDCCallbackParams(writer, req)
code, state, err := validateOIDCCallbackParams(req)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
rawIDToken, err := a.getIDTokenForOIDCCallback(req.Context(), code)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
idToken, err := a.verifyIDTokenForOIDCCallback(req.Context(), rawIDToken)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
}
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
idTokenExpiry := a.determineTokenExpiration(idToken.Expiry)
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
claims, err := extractIDTokenClaims(writer, idToken)
if err != nil {
var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
return
}
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return
}
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
machineKey, nodeExists, err := a.validateNodeForOIDCCallback(
writer,
state,
claims,
&claims,
idTokenExpiry,
)
if err != nil || nodeExists {
return
}
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
// register the node if it's new
log.Debug().Msg("Registering new node after successful callback")
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
if err := a.registerNodeForOIDCCallback(user, machineKey, idTokenExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
content, err := renderOIDCCallbackTemplate(writer, claims)
content, err := renderOIDCCallbackTemplate(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
@ -254,127 +270,57 @@ func (h *Headscale) OIDCCallback(
}
func validateOIDCCallbackParams(
writer http.ResponseWriter,
req *http.Request,
) (string, string, error) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
if code == "" || state == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", "", errEmptyOIDCCallbackParams
}
return code, state, nil
}
func (h *Headscale) getIDTokenForOIDCCallback(
func (a *AuthProviderOIDC) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
code string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
if err != nil {
util.LogErr(err, "Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
return "", fmt.Errorf("could not exchange code for token: %w", err)
}
log.Trace().
Caller().
Str("code", code).
Str("state", state).
Msg("Got oidc callback")
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return "", errNoOIDCIDToken
}
return rawIDToken, nil
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
func (a *AuthProviderOIDC) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
util.LogErr(err, "failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}
return idToken, nil
}
func extractIDTokenClaims(
writer http.ResponseWriter,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
util.LogErr(err, "Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
}
return &claims, nil
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
writer http.ResponseWriter,
allowedDomains []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
log.Trace().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedDomains
}
}
@ -387,9 +333,8 @@ func validateOIDCAllowedDomains(
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
writer http.ResponseWriter,
allowedGroups []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
@ -398,14 +343,6 @@ func validateOIDCAllowedGroups(
}
}
log.Trace().Msg("authenticated principal not in any allowed groups")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedGroups
}
@ -415,20 +352,12 @@ func validateOIDCAllowedGroups(
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
// that the authenticated principal is part of that list.
func validateOIDCAllowedUsers(
writer http.ResponseWriter,
allowedUsers []string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
log.Trace().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return errOIDCAllowedUsers
}
@ -439,40 +368,21 @@ func validateOIDCAllowedUsers(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
func (h *Headscale) validateNodeForOIDCCallback(
func (a *AuthProviderOIDC) validateNodeForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
claims *types.OIDCClaims,
expiry time.Time,
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
machineKeyIf, machineKeyFound := a.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
util.LogErr(err, "Failed to write response")
}
return nil, false, errOIDCInvalidNodeState
}
@ -480,7 +390,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByMachineKey(machineKey)
node, _ := a.db.GetNodeByMachineKey(machineKey)
if node != nil {
log.Trace().
@ -488,20 +398,13 @@ func (h *Headscale) validateNodeForOIDCCallback(
Str("node", node.Hostname).
Msg("node already registered, reauthenticating")
err := h.db.NodeSetExpiry(node.ID, expiry)
err := a.db.NodeSetExpiry(node.ID, expiry)
if err != nil {
util.LogErr(err, "Failed to refresh node")
http.Error(
writer,
"Failed to refresh node",
http.StatusInternalServerError,
)
return nil, true, err
}
log.Debug().
Str("node", node.Hostname).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Time("expiresAt", expiry).
Msg("successfully refreshed node")
var content bytes.Buffer
@ -509,13 +412,6 @@ func (h *Headscale) validateNodeForOIDCCallback(
User: claims.Email,
Verb: "Reauthenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, true, fmt.Errorf("rendering OIDC callback template: %w", err)
}
@ -527,7 +423,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
}
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
h.nodeNotifier.NotifyByNodeID(
a.notifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
@ -537,7 +433,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
return nil, true, nil
}
@ -545,79 +441,56 @@ func (h *Headscale) validateNodeForOIDCCallback(
return &machineKey, false, nil
}
func getUserName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
userName, err := util.NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
)
if err != nil {
util.LogErr(err, "couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return "", err
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, error) {
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return userName, nil
}
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter,
userName string,
) (*types.User, error) {
user, err := h.db.GetUser(userName)
if errors.Is(err, db.ErrUserNotFound) {
user, err = h.db.CreateUser(userName)
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, fmt.Errorf("creating new user: %w", err)
}
} else if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create user"))
if werr != nil {
util.LogErr(err, "Failed to write response")
// This check is for legacy, if the user cannot be found by the OIDC identifier
// look it up by username. This should only be needed once.
if user == nil {
user, err = a.db.GetUserByName(claims.Username)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return nil, fmt.Errorf("find or create user: %w", err)
// if the user is still not found, create a new empty user.
if user == nil {
user = &types.User{}
}
}
user.FromClaim(claims)
err = a.db.DB.Save(user).Error
if err != nil {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
return user, nil
}
func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter,
func (a *AuthProviderOIDC) registerNodeForOIDCCallback(
user *types.User,
machineKey *key.MachinePublic,
expiry time.Time,
) error {
ipv4, ipv6, err := h.ipAlloc.Next()
ipv4, ipv6, err := a.ipAlloc.Next()
if err != nil {
return err
}
if err := h.db.Write(func(tx *gorm.DB) error {
if err := a.db.Write(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
tx,
h.registrationCache,
a.registrationCache,
*machineKey,
// TODO(kradalby): Should be ID, not name
user.Name,
&expiry,
util.RegisterMethodOIDC,
@ -628,36 +501,20 @@ func (h *Headscale) registerNodeForOIDCCallback(
return nil
}); err != nil {
util.LogErr(err, "could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register node"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return err
return fmt.Errorf("could not register node: %w", err)
}
return nil
}
func renderOIDCCallbackTemplate(
writer http.ResponseWriter,
claims *IDTokenClaims,
claims *types.OIDCClaims,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
Verb: "Authenticated",
}); err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
}

View file

@ -59,46 +59,6 @@ func (h *Headscale) WindowsConfigMessage(
}
}
// WindowsRegConfig generates and serves a .reg file configured with the Headscale server address.
func (h *Headscale) WindowsRegConfig(
writer http.ResponseWriter,
req *http.Request,
) {
config := WindowsRegistryConfig{
URL: h.cfg.ServerURL,
}
var content bytes.Buffer
if err := windowsRegTemplate.Execute(&content, config); err != nil {
log.Error().
Str("handler", "WindowsRegConfig").
Err(err).
Msg("Could not render Apple macOS template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Windows registry template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
func (h *Headscale) AppleConfigMessage(
writer http.ResponseWriter,
@ -305,10 +265,6 @@ func (h *Headscale) ApplePlatformConfig(
}
}
type WindowsRegistryConfig struct {
URL string
}
type AppleMobileConfig struct {
UUID uuid.UUID
URL string
@ -320,14 +276,6 @@ type AppleMobilePlatformConfig struct {
URL string
}
var windowsRegTemplate = textTemplate.Must(
textTemplate.New("windowsconfig").Parse(`Windows Registry Editor Version 5.00
[HKEY_LOCAL_MACHINE\SOFTWARE\Tailscale IPN]
"UnattendedMode"="always"
"LoginURL"="{{.URL}}"
`))
var commonTemplate = textTemplate.Must(
textTemplate.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">

View file

@ -737,15 +737,7 @@ func (pol *ACLPolicy) expandUsersFromGroup(
ErrInvalidGroup,
)
}
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
if err != nil {
return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w",
group,
ErrInvalidGroup,
)
}
users = append(users, grp)
users = append(users, group)
}
return users, nil
@ -934,7 +926,7 @@ func (pol *ACLPolicy) TagsOfNode(
}
var found bool
for _, owner := range owners {
if node.User.Name == owner {
if node.User.Username() == owner {
found = true
}
}
@ -958,7 +950,7 @@ func (pol *ACLPolicy) TagsOfNode(
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
var out types.Nodes
for _, node := range nodes {
if node.User.Name == user {
if node.User.Username() == user {
out = append(out, node)
}
}

View file

@ -341,7 +341,7 @@ func TestParsing(t *testing.T) {
],
},
],
}
}
`,
want: []tailcfg.FilterRule{
{
@ -633,25 +633,6 @@ func Test_expandGroup(t *testing.T) {
want: []string{},
wantErr: true,
},
{
name: "Expand emails in group strip domains",
field: field{
pol: ACLPolicy{
Groups: Groups{
"group:admin": []string{
"joe.bar@gmail.com",
"john.doe@yahoo.fr",
},
},
},
},
args: args{
group: "group:admin",
stripEmail: true,
},
want: []string{"joe.bar", "john.doe"},
wantErr: false,
},
{
name: "Expand emails in group",
field: field{
@ -667,7 +648,7 @@ func Test_expandGroup(t *testing.T) {
args: args{
group: "group:admin",
},
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
want: []string{"joe.bar@gmail.com", "john.doe@yahoo.fr"},
wantErr: false,
},
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"math/rand/v2"
"net/http"
"slices"
"sort"
"strings"
"time"
@ -273,6 +274,12 @@ func (m *mapSession) serveLongPoll() {
return
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID) {
m.tracef("node removed, closing stream")
return
}
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()

View file

@ -46,9 +46,7 @@ func (s *Suite) ResetDB(c *check.C) {
Path: tmpDir + "/headscale_test.db",
},
},
OIDC: types.OIDCConfig{
StripEmaildomain: false,
},
OIDC: types.OIDCConfig{},
}
app, err = NewHeadscale(&cfg)

View file

@ -25,17 +25,48 @@
</head>
<body>
<h1>headscale: iOS configuration</h1>
<h2>GUI</h2>
<ol>
<li>
Install the official Tailscale iOS client from the
<a href="https://apps.apple.com/app/tailscale/id1470499037"
>App store</a
>
</li>
<li>
Open Tailscale and make sure you are <i>not</i> logged in to any account
</li>
<li>Open Settings on the iOS device</li>
<li>
Scroll down to the "third party apps" section, under "Game Center" or
"TV Provider"
</li>
<li>
Find Tailscale and select it
<ul>
<li>
If the iOS device was previously logged into Tailscale, switch the
"Reset Keychain" toggle to "on"
</li>
</ul>
</li>
<li>Enter "{{.URL}}" under "Alternate Coordination Server URL"</li>
<li>
Restart the app by closing it from the iOS app switcher, open the app
and select the regular sign in option <i>(non-SSO)</i>. It should open
up to the headscale authentication page.
</li>
<li>
Enter your credentials and log in. Headscale should now be working on
your iOS device
</li>
</ol>
<h1>headscale: macOS configuration</h1>
<h2>Recent Tailscale versions (1.34.0 and higher)</h2>
<p>
Tailscale added Fast User Switching in version 1.34 and you can now use
the new login command to connect to one or more headscale (and Tailscale)
servers. The previously used profiles does not have an effect anymore.
</p>
<h3>Command line</h3>
<h2>Command line</h2>
<p>Use Tailscale's login command to add your profile:</p>
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
<h3>GUI</h3>
<h2>GUI</h2>
<ol>
<li>
ALT + Click the Tailscale icon in the menu and hover over the Debug menu
@ -46,44 +77,7 @@
</li>
<li>Follow the login procedure in the browser</li>
</ol>
<h2>Apple configuration profiles (1.32.0 and lower)</h2>
<p>
This page provides
<a href="https://support.apple.com/guide/mdm/mdm-overview-mdmbf9e668/web"
>configuration profiles</a
>
for the official Tailscale clients for
</p>
<ul>
<li>
<a href="https://apps.apple.com/app/tailscale/id1475387142"
>macOS - AppStore Client</a
>.
</li>
<li>
<a href="https://pkgs.tailscale.com/stable/#macos"
>macOS - Standalone Client</a
>.
</li>
</ul>
<p>
The profiles will configure Tailscale.app to use <code>{{.URL}}</code> as
its control server.
</p>
<h3>Caution</h3>
<p>
You should always download and inspect the profile before installing it:
</p>
<ul>
<li>
for app store client: <code>curl {{.URL}}/apple/macos-app-store</code>
</li>
<li>
for standalone client: <code>curl {{.URL}}/apple/macos-standalone</code>
</li>
</ul>
<h2>Profiles</h2>
<h3>macOS</h3>
<p>
Headscale can be set to the default server by installing a Headscale
configuration profile:
@ -121,50 +115,17 @@
</li>
</ul>
<p>Restart Tailscale.app and log in.</p>
<h1>headscale: iOS configuration</h1>
<h2>Recent Tailscale versions (1.38.1 and higher)</h2>
<h3>Caution</h3>
<p>
Tailscale 1.38.1 on
<a href="https://apps.apple.com/app/tailscale/id1470499037">iOS</a>
added a configuration option to allow user to set an "Alternate
Coordination server". This can be used to connect to your headscale
server.
You should always download and inspect the profile before installing it:
</p>
<h3>GUI</h3>
<ol>
<ul>
<li>
Install the official Tailscale iOS client from the
<a href="https://apps.apple.com/app/tailscale/id1470499037"
>App store</a
>
for app store client: <code>curl {{.URL}}/apple/macos-app-store</code>
</li>
<li>
Open Tailscale and make sure you are <i>not</i> logged in to any account
for standalone client: <code>curl {{.URL}}/apple/macos-standalone</code>
</li>
<li>Open Settings on the iOS device</li>
<li>
Scroll down to the "third party apps" section, under "Game Center" or
"TV Provider"
</li>
<li>
Find Tailscale and select it
<ul>
<li>
If the iOS device was previously logged into Tailscale, switch the
"Reset Keychain" toggle to "on"
</li>
</ul>
</li>
<li>Enter "{{.URL}}" under "Alternate Coordination Server URL"</li>
<li>
Restart the app by closing it from the iOS app switcher, open the app
and select the regular sign in option <i>(non-SSO)</i>. It should open
up to the headscale authentication page.
</li>
<li>
Enter your credentials and log in. Headscale should now be working on
your iOS device
</li>
</ol>
</ul>
</body>
</html>

View file

@ -25,75 +25,21 @@
<body>
<h1>headscale: Windows configuration</h1>
<h2>Recent Tailscale versions (1.34.0 and higher)</h2>
<p>
Tailscale added Fast User Switching in version 1.34 and you can now use
the new login command to connect to one or more headscale (and Tailscale)
servers. The previously used profiles does not have an effect anymore.
</p>
<p>Use Tailscale's login command to add your profile:</p>
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
<h2>Windows registry configuration (1.32.0 and lower)</h2>
<p>
This page provides Windows registry information for the official Windows
Tailscale client.
</p>
<p></p>
<p>
The registry file will configure Tailscale to use <code>{{.URL}}</code> as
its control server.
</p>
<p></p>
<h3>Caution</h3>
<p>
You should always download and inspect the registry file before installing
it:
</p>
<pre><code>curl {{.URL}}/windows/tailscale.reg</code></pre>
<h2>Installation</h2>
<p>
Headscale can be set to the default server by running the registry file:
</p>
<p>
<a href="/windows/tailscale.reg" download="tailscale.reg"
>Windows registry file</a
Download
<a
href="https://tailscale.com/download/windows"
rel="noreferrer noopener"
target="_blank"
>Tailscale for Windows</a
>
and install it.
</p>
<ol>
<li>Download the registry file, then run it</li>
<li>Follow the prompts</li>
<li>Install and run the official windows Tailscale client</li>
<li>
When the installation has finished, start Tailscale, and log in by
clicking the icon in the system tray
</li>
</ol>
<p>Or using REG:</p>
<p>
Open command prompt with Administrator rights. Issue the following
commands to add the required registry entries:
Open a Command Prompt or Powershell and use Tailscale's login command to
connect with headscale:
</p>
<pre>
<code>REG ADD "HKLM\Software\Tailscale IPN" /v UnattendedMode /t REG_SZ /d always
REG ADD "HKLM\Software\Tailscale IPN" /v LoginURL /t REG_SZ /d "{{.URL}}"</code>
</pre>
<p>Or using Powershell</p>
<p>
Open Powershell with Administrator rights. Issue the following commands to
add the required registry entries:
</p>
<pre>
<code>New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value "{{.URL}}"</code>
</pre>
<p>Finally, restart Tailscale and log in.</p>
<p></p>
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
</body>
</html>

View file

@ -71,8 +71,7 @@ type Config struct {
ACMEURL string
ACMEEmail string
DNSConfig *tailcfg.DNSConfig
DNSUserNameInMagicDNS bool
DNSConfig *tailcfg.DNSConfig
UnixSocket string
UnixSocketPermission fs.FileMode
@ -90,12 +89,11 @@ type Config struct {
}
type DNSConfig struct {
MagicDNS bool `mapstructure:"magic_dns"`
BaseDomain string `mapstructure:"base_domain"`
Nameservers Nameservers
SearchDomains []string `mapstructure:"search_domains"`
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
UserNameInMagicDNS bool `mapstructure:"use_username_in_magic_dns"`
MagicDNS bool `mapstructure:"magic_dns"`
BaseDomain string `mapstructure:"base_domain"`
Nameservers Nameservers
SearchDomains []string `mapstructure:"search_domains"`
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
}
type Nameservers struct {
@ -164,7 +162,6 @@ type OIDCConfig struct {
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
StripEmaildomain bool
Expiry time.Duration
UseExpiryFromToken bool
}
@ -212,6 +209,12 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
// It has to be called before LoadServerConfig and LoadCLIConfig.
// The configuration is not validated and the caller should check for errors
// using a validation function.
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
@ -268,7 +271,6 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("database.sqlite.write_ahead_log", true)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
@ -284,14 +286,14 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
}
return nil
}
func validateServerConfig() error {
depr := deprecator{
warns: make(set.Set[string]),
fatals: make(set.Set[string]),
@ -315,8 +317,22 @@ func LoadConfig(path string, isFile bool) error {
depr.warn("dns_config.use_username_in_magic_dns")
depr.warn("dns.use_username_in_magic_dns")
depr.fatal("oidc.strip_email_domain")
depr.fatal("dns.use_username_in_musername_in_magic_dns")
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
depr.Log()
for _, removed := range []string{
"oidc.strip_email_domain",
"dns_config.use_username_in_musername_in_magic_dns",
} {
if viper.IsSet(removed) {
log.Fatal().
Msgf("Fatal config error: %s has been removed. Please remove it from your config file", removed)
}
}
// Collect any validation errors and return them all at once
var errorText string
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
@ -360,12 +376,12 @@ func LoadConfig(path string, isFile bool) error {
if errorText != "" {
// nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
return nil
}
func GetTLSConfig() TLSConfig {
func tlsConfig() TLSConfig {
return TLSConfig{
LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"),
@ -384,7 +400,7 @@ func GetTLSConfig() TLSConfig {
}
}
func GetDERPConfig() DERPConfig {
func derpConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
@ -445,7 +461,7 @@ func GetDERPConfig() DERPConfig {
}
}
func GetLogTailConfig() LogTailConfig {
func logtailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{
@ -453,7 +469,7 @@ func GetLogTailConfig() LogTailConfig {
}
}
func GetPolicyConfig() PolicyConfig {
func policyConfig() PolicyConfig {
policyPath := viper.GetString("policy.path")
policyMode := viper.GetString("policy.mode")
@ -463,7 +479,7 @@ func GetPolicyConfig() PolicyConfig {
}
}
func GetLogConfig() LogConfig {
func logConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
@ -473,9 +489,9 @@ func GetLogConfig() LogConfig {
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
case JSONLogFormat:
logFormat = JSONLogFormat
case "text":
case TextLogFormat:
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
@ -491,7 +507,7 @@ func GetLogConfig() LogConfig {
}
}
func GetDatabaseConfig() DatabaseConfig {
func databaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
@ -543,7 +559,7 @@ func GetDatabaseConfig() DatabaseConfig {
}
}
func DNS() (DNSConfig, error) {
func dns() (DNSConfig, error) {
var dns DNSConfig
// TODO: Use this instead of manually getting settings when
@ -566,21 +582,18 @@ func DNS() (DNSConfig, error) {
if err != nil {
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
}
dns.ExtraRecords = extraRecords
}
dns.UserNameInMagicDNS = viper.GetBool("dns.use_username_in_magic_dns")
return dns, nil
}
// GlobalResolvers returns the global DNS resolvers
// globalResolvers returns the global DNS resolvers
// defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global {
@ -613,11 +626,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
return resolvers
}
// SplitResolvers returns a map of domain to DNS resolvers.
// splitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver
@ -653,7 +666,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
return routes
}
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS {
@ -662,9 +675,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers()
cfg.Resolvers = dns.globalResolvers()
routes := dns.SplitResolvers()
routes := dns.splitResolvers()
cfg.Routes = routes
if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain}
@ -674,7 +687,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
return &cfg
}
func PrefixV4() (*netip.Prefix, error) {
func prefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
if prefixV4Str == "" {
@ -698,7 +711,7 @@ func PrefixV4() (*netip.Prefix, error) {
return &prefixV4, nil
}
func PrefixV6() (*netip.Prefix, error) {
func prefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" {
@ -723,27 +736,41 @@ func PrefixV6() (*netip.Prefix, error) {
return &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
return &Config{
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
}
logConfig := GetLogConfig()
// LoadCLIConfig returns the needed configuration for the CLI client
// of Headscale to connect to a Headscale server.
func LoadCLIConfig() (*Config, error) {
logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := PrefixV4()
return &Config{
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
UnixSocket: viper.GetString("unix_socket"),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
Log: logConfig,
}, nil
}
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
if err := validateServerConfig(); err != nil {
return nil, err
}
logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := prefixV4()
if err != nil {
return nil, err
}
prefix6, err := PrefixV6()
prefix6, err := prefixV6()
if err != nil {
return nil, err
}
@ -760,16 +787,21 @@ func GetHeadscaleConfig() (*Config, error) {
case string(IPAllocationStrategyRandom):
alloc = IPAllocationStrategyRandom
default:
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
return nil, fmt.Errorf(
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
allocStr,
IPAllocationStrategySequential,
IPAllocationStrategyRandom,
)
}
dnsConfig, err := DNS()
dnsConfig, err := dns()
if err != nil {
return nil, err
}
derpConfig := GetDERPConfig()
logTailConfig := GetLogTailConfig()
derpConfig := derpConfig()
logTailConfig := logtailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
oidcClientSecret := viper.GetString("oidc.client_secret")
@ -794,10 +826,11 @@ func GetHeadscaleConfig() (*Config, error) {
// - DERP run on their own domains
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
//
// TODO(kradalby): remove dnsConfig.UserNameInMagicDNS check when removed.
if !dnsConfig.UserNameInMagicDNS && dnsConfig.BaseDomain != "" && strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New("server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.")
if dnsConfig.BaseDomain != "" &&
strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New(
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
)
}
return &Config{
@ -806,7 +839,7 @@ func GetHeadscaleConfig() (*Config, error) {
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
DisableUpdateCheck: false,
PrefixV4: prefix4,
PrefixV6: prefix6,
@ -823,12 +856,11 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout",
),
Database: GetDatabaseConfig(),
Database: databaseConfig(),
TLS: GetTLSConfig(),
TLS: tlsConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
DNSConfig: dnsToTailcfgDNS(dnsConfig),
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
@ -840,15 +872,14 @@ func GetHeadscaleConfig() (*Config, error) {
OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available",
),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
Expiry: func() time.Duration {
// if set to 0, we assume no expiry
if value := viper.GetString("oidc.expiry"); value == "0" {
@ -870,7 +901,7 @@ func GetHeadscaleConfig() (*Config, error) {
LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort,
Policy: GetPolicyConfig(),
Policy: policyConfig(),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
@ -883,17 +914,15 @@ func GetHeadscaleConfig() (*Config, error) {
// TODO(kradalby): Document these settings when more stable
Tuning: Tuning{
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
NodeMapSessionBufferedChanSize: viper.GetInt(
"tuning.node_mapsession_buffered_chan_size",
),
},
}, nil
}
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]
@ -905,14 +934,26 @@ func (d *deprecator) warnWithAlias(newKey, oldKey string) {
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
viper.RegisterAlias(newKey, oldKey)
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.",
oldKey,
newKey,
oldKey,
),
)
}
}
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
func (d *deprecator) fatal(newKey, oldKey string) {
func (d *deprecator) fatal(oldKey string) {
if viper.IsSet(oldKey) {
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.fatals.Add(
fmt.Sprintf(
"The %q configuration key has been removed. Please see the changelog for more details.",
oldKey,
),
)
}
}
@ -920,7 +961,14 @@ func (d *deprecator) fatal(newKey, oldKey string) {
// If the new key is set, a warning is emitted instead.
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.fatals.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
oldKey,
newKey,
oldKey,
),
)
} else if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
}
@ -929,14 +977,26 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
// warn deprecates and adds an option to log a warning if the oldKey is set.
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
oldKey,
newKey,
oldKey,
),
)
}
}
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
func (d *deprecator) warn(oldKey string) {
if viper.IsSet(oldKey) {
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", oldKey))
d.warns.Add(
fmt.Sprintf(
"The %q configuration key is deprecated and has been removed. Please see the changelog for more details.",
oldKey,
),
)
}
}

View file

@ -1,6 +1,8 @@
package types
import (
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-config",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -40,20 +42,19 @@ func TestReadConfig(t *testing.T) {
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
{
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: true,
@ -79,7 +80,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-no-magic",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -97,20 +98,19 @@ func TestReadConfig(t *testing.T) {
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
{
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: false,
@ -136,7 +136,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-in-server-url-err",
configPath: "testdata/base-domain-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
return GetHeadscaleConfig()
return LoadServerConfig()
},
want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
@ -145,7 +145,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-not-in-server-url",
configPath: "testdata/base-domain-not-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -165,7 +165,7 @@ func TestReadConfig(t *testing.T) {
name: "policy-path-is-loaded",
configPath: "testdata/policy-path-is-loaded.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -232,11 +232,10 @@ func TestReadConfigFromEnv(t *testing.T) {
{
name: "unmarshal-dns-full-config",
configEnv: map[string]string{
"HEADSCALE_DNS_MAGIC_DNS": "true",
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
"HEADSCALE_DNS_USE_USERNAME_IN_MAGIC_DNS": "true",
"HEADSCALE_DNS_MAGIC_DNS": "true",
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
// TODO(kradalby): Figure out how to pass these as env vars
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
@ -245,7 +244,7 @@ func TestReadConfigFromEnv(t *testing.T) {
setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings())
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -264,8 +263,7 @@ func TestReadConfigFromEnv(t *testing.T) {
ExtraRecords: []tailcfg.DNSRecord{
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
},
SearchDomains: []string{"test.com", "bar.com"},
UserNameInMagicDNS: true,
SearchDomains: []string{"test.com", "bar.com"},
},
},
}
@ -289,3 +287,49 @@ func TestReadConfigFromEnv(t *testing.T) {
})
}
}
func TestTLSConfigValidation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
t.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
// Check configuration validation errors (1)
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
err = validateServerConfig()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
}

View file

@ -393,7 +393,7 @@ func (node *Node) Proto() *v1.Node {
return nodeProto
}
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
func (node *Node) GetFQDN(baseDomain string) (string, error) {
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
@ -408,19 +408,6 @@ func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
)
}
if cfg.DNSUserNameInMagicDNS {
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
node.User.Name,
baseDomain,
)
}
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"failed to create valid FQDN (%s): %w",

View file

@ -127,76 +127,10 @@ func TestNodeFQDN(t *testing.T) {
tests := []struct {
name string
node Node
cfg Config
domain string
want string
wantErr string
}{
{
name: "all-set-with-username",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
want: "test.user.example.com",
},
{
name: "no-given-name-with-username",
node: Node{
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node has no given name",
},
{
name: "no-user-name-with-username",
node: Node{
GivenName: "test",
User: User{},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node user has no name",
},
{
name: "no-magic-dns-with-username",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: false,
},
DNSUserNameInMagicDNS: true,
},
domain: "example.com",
want: "test.user.example.com",
},
{
name: "no-dnsconfig-with-username",
node: Node{
@ -216,12 +150,6 @@ func TestNodeFQDN(t *testing.T) {
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
},
@ -232,46 +160,16 @@ func TestNodeFQDN(t *testing.T) {
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node has no given name",
},
{
name: "no-user-name",
name: "too-long-username",
node: Node{
GivenName: "test",
User: User{},
GivenName: "useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111",
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
},
{
name: "no-magic-dns",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
cfg: Config{
DNSConfig: &tailcfg.DNSConfig{
Proxied: false,
},
DNSUserNameInMagicDNS: false,
},
domain: "example.com",
want: "test.example.com",
domain: "example.com",
wantErr: "failed to create valid FQDN (useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111.example.com): hostname too long, cannot except 255 ASCII chars",
},
{
name: "no-dnsconfig",
@ -288,7 +186,9 @@ func TestNodeFQDN(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.node.GetFQDN(&tc.cfg, tc.domain)
got, err := tc.node.GetFQDN(tc.domain)
t.Logf("GOT: %q, %q", got, tc.domain)
if (err != nil) && (err.Error() != tc.wantErr) {
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)

View file

@ -1,6 +1,7 @@
package types
import (
"cmp"
"strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -16,19 +17,57 @@ import (
// that contain our machines.
type User struct {
gorm.Model
// Username for the user, is used if email is empty
// Should not be used, please use Username().
Name string `gorm:"unique"`
// Typically a full name of the user
DisplayName string
// Email of the user
// Should not be used, please use Username().
Email string
// Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token
// and is used to lookup the user.
ProviderIdentifier string `gorm:"index"`
// Provider is the origin of the user account,
// same as RegistrationMethod, without authkey.
Provider string
ProfilePicURL string
}
// Username is the main way to get the username of a user,
// it will return the email if it exists, the name if it exists,
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
// Email and OIDCIdentifier will be set when the user has headscale
// enabled with OIDC, which means that there is a domain involved which
// should be used throughout headscale, in information returned to the
// user and the Policy engine.
func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
}
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
// it will return the Username.
func (u *User) DisplayNameOrUsername() string {
return cmp.Or(u.DisplayName, u.Username())
}
// TODO(kradalby): See if we can fill in Gravatar here
func (u *User) profilePicURL() string {
return ""
return u.ProfilePicURL
}
func (u *User) TailscaleUser() *tailcfg.User {
user := tailcfg.User{
ID: tailcfg.UserID(u.ID),
LoginName: u.Name,
DisplayName: u.Name,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
Logins: []tailcfg.LoginID{},
Created: u.CreatedAt,
@ -41,9 +80,9 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
ID: tailcfg.LoginID(u.ID),
// TODO(kradalby): this should reflect registration method.
Provider: "",
LoginName: u.Name,
DisplayName: u.Name,
Provider: u.Provider,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
}
@ -53,8 +92,8 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
return tailcfg.UserProfile{
ID: tailcfg.UserID(u.ID),
LoginName: u.Name,
DisplayName: u.Name,
LoginName: u.Username(),
DisplayName: u.DisplayNameOrUsername(),
ProfilePicURL: u.profilePicURL(),
}
}
@ -66,3 +105,27 @@ func (n *User) Proto() *v1.User {
CreatedAt: timestamppb.New(n.CreatedAt),
}
}
type OIDCClaims struct {
// Sub is the user's unique identifier at the provider.
Sub string `json:"sub"`
// Name is the user's full name.
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
ProfilePictureURL string `json:"picture,omitempty"`
Username string `json:"preferred_username,omitempty"`
}
// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Sub
u.DisplayName = claims.Username
u.Email = claims.Email
u.Name = claims.Username
u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC
}

View file

@ -7,7 +7,6 @@ import (
"regexp"
"strings"
"github.com/spf13/viper"
"go4.org/netipx"
"tailscale.com/util/dnsname"
)
@ -25,38 +24,6 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidUserName = errors.New("invalid user name")
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
strip := viper.GetBool("oidc.strip_email_domain")
return NormalizeToFQDNRules(name, strip)
}
// NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > LabelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidUserName,
)
}
}
return name, nil
}
func CheckForFQDNRules(name string) error {
if len(name) > LabelHostnameLength {
return fmt.Errorf(

View file

@ -7,100 +7,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestNormalizeToFQDNRules(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "user name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "user with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeToFQDNRules() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckForFQDNRules(t *testing.T) {
type args struct {
name string

View file

@ -276,7 +276,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
hsic.WithACLPolicy(&testCase.policy),
)
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErr(t, err)
@ -316,7 +316,7 @@ func TestACLAllowUser80Dst(t *testing.T) {
},
1,
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err)
@ -373,7 +373,7 @@ func TestACLDenyAllPort80(t *testing.T) {
},
4,
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErr(t, err)
@ -417,7 +417,7 @@ func TestACLAllowUserDst(t *testing.T) {
},
2,
)
// defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err)
@ -473,7 +473,7 @@ func TestACLAllowStarDst(t *testing.T) {
},
2,
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err)
@ -534,7 +534,7 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
},
3,
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
user1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErr(t, err)
@ -672,7 +672,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
&testCase.policy,
2,
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
// Since user/users dont matter here, we basically expect that some clients
// will be assigned these ips and that we can pick them up for our own use.
@ -1021,7 +1021,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,

View file

@ -48,7 +48,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -62,7 +62,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
}
err = scenario.CreateHeadscaleEnv(
@ -108,7 +107,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 3,
@ -121,7 +120,6 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
}
@ -276,7 +274,6 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
),
ClientID: "superclient",
ClientSecret: "supersecret",
StripEmaildomain: true,
OnlyStartIfOIDCIsAvailable: true,
}, nil
}

View file

@ -34,7 +34,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -73,7 +73,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
scenario := AuthWebFlowScenario{
Scenario: baseScenario,
}
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"
"testing"
"time"
@ -34,7 +35,7 @@ func TestUserCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
@ -114,7 +115,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
@ -256,7 +257,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
@ -319,7 +320,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 0,
@ -397,7 +398,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user1: 1,
@ -491,7 +492,7 @@ func TestApiKeyCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
@ -659,7 +660,7 @@ func TestNodeTagCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 0,
@ -735,13 +736,7 @@ func TestNodeTagCommand(t *testing.T) {
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
// try to set a wrong tag and retrieve the error
type errOutput struct {
Error string `json:"error"`
}
var errorOutput errOutput
err = executeAndUnmarshal(
headscale,
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
@ -750,10 +745,8 @@ func TestNodeTagCommand(t *testing.T) {
"-t", "wrong-tag",
"--output", "json",
},
&errorOutput,
)
assert.Nil(t, err)
assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'")
assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, len(machineKeys))
@ -792,7 +785,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
@ -842,7 +835,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
@ -905,7 +898,7 @@ func TestNodeCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-user": 0,
@ -1146,7 +1139,7 @@ func TestNodeExpireCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-expire-user": 0,
@ -1273,7 +1266,7 @@ func TestNodeRenameCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"node-rename-command": 0,
@ -1398,18 +1391,17 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5")
// Test failure for too long names
result, err := headscale.Execute(
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[4].GetId()),
"testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890",
strings.Repeat("t", 64),
},
)
assert.Nil(t, err)
assert.Contains(t, result, "not be over 63 chars")
assert.ErrorContains(t, err, "not be over 63 chars")
var listAllAfterRenameAttempt []v1.Node
err = executeAndUnmarshal(
@ -1440,7 +1432,7 @@ func TestNodeMoveCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"old-user": 0,
@ -1536,7 +1528,7 @@ func TestNodeMoveCommand(t *testing.T) {
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
moveToNonExistingNSResult, err := headscale.Execute(
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
@ -1549,11 +1541,9 @@ func TestNodeMoveCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.Contains(
assert.ErrorContains(
t,
moveToNonExistingNSResult,
err,
"user not found",
)
assert.Equal(t, node.GetUser().GetName(), "new-user")
@ -1603,7 +1593,7 @@ func TestPolicyCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"policy-user": 0,
@ -1683,7 +1673,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"policy-user": 1,

View file

@ -6,8 +6,8 @@ import (
)
type ControlServer interface {
Shutdown() error
SaveLog(string) error
Shutdown() (string, string, error)
SaveLog(string) (string, string, error)
SaveProfile(string) error
Execute(command []string) (string, error)
WriteFile(path string, content []byte) error

View file

@ -17,7 +17,7 @@ func TestResolveMagicDNS(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"magicdns1": len(MustTestVersions),
@ -208,7 +208,7 @@ func TestValidateResolvConf(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"resolvconf1": 3,

View file

@ -17,10 +17,10 @@ func SaveLog(
pool *dockertest.Pool,
resource *dockertest.Resource,
basePath string,
) error {
) (string, string, error) {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
return "", "", err
}
var stdout bytes.Buffer
@ -41,28 +41,30 @@ func SaveLog(
},
)
if err != nil {
return err
return "", "", err
}
log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
stdoutPath := path.Join(basePath, resource.Container.Name+".stdout.log")
err = os.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
stdoutPath,
stdout.Bytes(),
filePerm,
)
if err != nil {
return err
return "", "", err
}
stderrPath := path.Join(basePath, resource.Container.Name+".stderr.log")
err = os.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
stderrPath,
stderr.Bytes(),
filePerm,
)
if err != nil {
return err
return "", "", err
}
return nil
return stdoutPath, stderrPath, nil
}

View file

@ -32,7 +32,7 @@ func TestDERPServerScenario(t *testing.T) {
Scenario: baseScenario,
tsicNetworks: map[string]*dockertest.Network{},
}
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),

View file

@ -27,7 +27,7 @@ func TestPingAllByIP(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
@ -71,7 +71,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -109,7 +109,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -228,7 +228,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -313,7 +313,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -427,7 +427,7 @@ func TestPingAllByHostname(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user3": len(MustTestVersions),
@ -476,7 +476,7 @@ func TestTaildrop(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"taildrop": len(MustTestVersions),
@ -637,7 +637,7 @@ func TestExpireNode(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -763,7 +763,7 @@ func TestNodeOnlineStatus(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
@ -878,7 +878,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
@ -954,3 +954,102 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
}
func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": 1,
"user2": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("deletenocrash"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test list all nodes after added otherUser
var nodeList []v1.Node
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeList,
)
assert.Nil(t, err)
assert.Len(t, nodeList, 2)
assert.True(t, nodeList[0].Online)
assert.True(t, nodeList[1].Online)
// Delete the first node, which is online
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"delete",
"--identifier",
// Delete the last added machine
fmt.Sprintf("%d", nodeList[0].Id),
"--output",
"json",
"--force",
},
)
assert.Nil(t, err)
time.Sleep(2 * time.Second)
// Ensure that the node has been deleted, this did not occur due to a panic.
var nodeListAfter []v1.Node
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfter,
)
assert.Nil(t, err)
assert.Len(t, nodeListAfter, 1)
assert.True(t, nodeListAfter[0].Online)
assert.Equal(t, nodeList[1].Id, nodeListAfter[0].Id)
}

View file

@ -398,8 +398,8 @@ func (t *HeadscaleInContainer) hasTLS() bool {
}
// Shutdown stops and cleans up the Headscale container.
func (t *HeadscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
func (t *HeadscaleInContainer) Shutdown() (string, string, error) {
stdoutPath, stderrPath, err := t.SaveLog("/tmp/control")
if err != nil {
log.Printf(
"Failed to save log from control: %s",
@ -458,12 +458,12 @@ func (t *HeadscaleInContainer) Shutdown() error {
t.pool.Purge(t.pgContainer)
}
return t.pool.Purge(t.container)
return stdoutPath, stderrPath, t.pool.Purge(t.container)
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *HeadscaleInContainer) SaveLog(path string) error {
func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) {
return dockertestutil.SaveLog(t.pool, t.container, path)
}

View file

@ -32,7 +32,7 @@ func TestEnablingRoutes(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 3,
@ -254,7 +254,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 3,
@ -826,7 +826,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 1,
@ -968,7 +968,7 @@ func TestAutoApprovedSubRoute2068(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 1,
@ -1059,7 +1059,7 @@ func TestSubnetRouteACL(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
user: 2,

View file

@ -8,6 +8,7 @@ import (
"os"
"sort"
"sync"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -18,6 +19,7 @@ import (
"github.com/ory/dockertest/v3"
"github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"tailscale.com/envknob"
)
@ -187,13 +189,9 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) {
}, nil
}
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
// and networks associated with it.
// In addition, it will save the logs of the ControlServer to `/tmp/control` in the
// environment running the tests.
func (s *Scenario) Shutdown() {
func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
s.controlServers.Range(func(_ string, control ControlServer) bool {
err := control.Shutdown()
stdoutPath, stderrPath, err := control.Shutdown()
if err != nil {
log.Printf(
"Failed to shut down control: %s",
@ -201,6 +199,16 @@ func (s *Scenario) Shutdown() {
)
}
if t != nil {
stdout, err := os.ReadFile(stdoutPath)
assert.NoError(t, err)
assert.NotContains(t, string(stdout), "panic")
stderr, err := os.ReadFile(stderrPath)
assert.NoError(t, err)
assert.NotContains(t, string(stderr), "panic")
}
return true
})
@ -224,6 +232,14 @@ func (s *Scenario) Shutdown() {
// }
}
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
// and networks associated with it.
// In addition, it will save the logs of the ControlServer to `/tmp/control` in the
// environment running the tests.
func (s *Scenario) Shutdown() {
s.ShutdownAssertNoPanics(nil)
}
// Users returns the name of all users associated with the Scenario.
func (s *Scenario) Users() []string {
users := make([]string, 0)

View file

@ -7,7 +7,7 @@ import (
)
// This file is intended to "test the test framework", by proxy it will also test
// some Headcsale/Tailscale stuff, but mostly in very simple ways.
// some Headscale/Tailscale stuff, but mostly in very simple ways.
func IntegrationSkip(t *testing.T) {
t.Helper()
@ -35,7 +35,7 @@ func TestHeadscale(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
t.Run("start-headscale", func(t *testing.T) {
headscale, err := scenario.Headscale()
@ -80,7 +80,7 @@ func TestCreateTailscale(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
scenario.users[user] = &User{
Clients: make(map[string]TailscaleClient),
@ -116,7 +116,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
t.Run("start-headscale", func(t *testing.T) {
headscale, err := scenario.Headscale()

View file

@ -111,7 +111,7 @@ func TestSSHOneUserToAll(t *testing.T) {
},
len(MustTestVersions),
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
@ -176,7 +176,7 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
},
len(MustTestVersions),
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
nsOneClients, err := scenario.ListTailscaleClients("user1")
assertNoErrListClients(t, err)
@ -222,7 +222,7 @@ func TestSSHNoSSHConfigured(t *testing.T) {
},
len(MustTestVersions),
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
@ -271,7 +271,7 @@ func TestSSHIsBlockedInACL(t *testing.T) {
},
len(MustTestVersions),
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
@ -327,7 +327,7 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
},
len(MustTestVersions),
)
defer scenario.Shutdown()
defer scenario.ShutdownAssertNoPanics(t)
ssh1Clients, err := scenario.ListTailscaleClients("user1")
assertNoErrListClients(t, err)

View file

@ -998,7 +998,9 @@ func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *TailscaleInContainer) SaveLog(path string) error {
return dockertestutil.SaveLog(t.pool, t.container, path)
// TODO(kradalby): Assert if tailscale logs contains panics.
_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)
return err
}
// ReadFile reads a file from the Tailscale container.

View file

@ -10,7 +10,7 @@ repo_name: juanfont/headscale
repo_url: https://github.com/juanfont/headscale
# Copyright
copyright: Copyright &copy; 2023 Headscale authors
copyright: Copyright &copy; 2024 Headscale authors
# Configuration
theme:
@ -55,6 +55,13 @@ theme:
favicon: assets/favicon.png
logo: ./logo/headscale3-dots.svg
# Excludes
exclude_docs: |
/packaging/README.md
/packaging/postinstall.sh
/packaging/postremove.sh
/requirements.txt
# Plugins
plugins:
- search:
@ -139,5 +146,5 @@ nav:
- Remote CLI: remote-cli.md
- Usage:
- Android: android-client.md
- Apple: apple-client.md
- Windows: windows-client.md
- iOS: iOS-client.md