mirror of
https://github.com/juanfont/headscale.git
synced 2024-09-20 07:16:35 +08:00
Compare commits
28 commits
038aca1fe2
...
b86db9a733
Author | SHA1 | Date | |
---|---|---|---|
b86db9a733 | |||
e66d149cee | |||
ba0f844d5e | |||
fb5d40f71b | |||
4850bb0c1c | |||
dc75a4f7bc | |||
d2efc63ca8 | |||
e9b95d2278 | |||
3151c629dc | |||
be2c00d4f8 | |||
8e07f09f3b | |||
14da7c436a | |||
10a72e8d54 | |||
ed78ecda12 | |||
6cbbcd859c | |||
e9d9c0773c | |||
fe68f50328 | |||
c3ef90a7f7 | |||
064c46f2a5 | |||
64319f79ff | |||
4b02dc9565 | |||
7be8796d87 | |||
99f18f9cd9 | |||
c3b260a6f7 | |||
60b94b0467 | |||
bac7ea67f4 | |||
5597edac1e | |||
8a3a0fee3c |
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
|
@ -52,6 +52,7 @@ jobs:
|
|||
- TestExpireNode
|
||||
- TestNodeOnlineStatus
|
||||
- TestPingAllByIPManyUpDown
|
||||
- Test2118DeletingOnlineNodePanics
|
||||
- TestEnablingRoutes
|
||||
- TestHASubnetRouterFailover
|
||||
- TestEnableDisableAutoApprovedRoute
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -22,6 +22,7 @@ dist/
|
|||
/headscale
|
||||
config.json
|
||||
config.yaml
|
||||
config*.yaml
|
||||
derp.yaml
|
||||
*.hujson
|
||||
*.key
|
||||
|
|
23
CHANGELOG.md
23
CHANGELOG.md
|
@ -1,8 +1,22 @@
|
|||
# CHANGELOG
|
||||
|
||||
## 0.23.0 (2023-XX-XX)
|
||||
## Next
|
||||
|
||||
This release is mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
|
||||
### BREAKING
|
||||
|
||||
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||
- Having usernames in magic DNS is no longer possible.
|
||||
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
|
||||
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
|
||||
- User has been extended to store username, display name, profile picture url and email.
|
||||
- These fields are forwarded to the client, and shows up nicely in the user switcher.
|
||||
- These fields can be made available via the API/CLI for non-OIDC users in the future.
|
||||
|
||||
## 0.23.0 (2023-09-18)
|
||||
|
||||
This release was intended to be mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project.
|
||||
However, as you all have noticed, it turned out to become a much larger, much longer release cycle than anticipated. It has ended up to be a release with a lot of rewrites and changes to the code base and functionality of Headscale, cleaning up a lot of technical debt and introducing a lot of improvements. This does come with some breaking changes,
|
||||
|
||||
**Please remember to always back up your database between versions**
|
||||
|
||||
|
@ -16,7 +30,7 @@ The [“poller”, or streaming logic](https://github.com/juanfont/headscale/blo
|
|||
|
||||
Headscale now supports sending “delta” updates, thanks to the new mapper and poller logic, allowing us to only inform nodes about new nodes, changed nodes and removed nodes. Previously we sent the entire state of the network every time an update was due.
|
||||
|
||||
While we have a pretty good [test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) for validating our changes, we have rewritten over [10000 lines of code](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...main) and bugs are expected. We need help testing this release. In addition, while we think the performance should in general be better, there might be regressions in parts of the platform, particularly where we prioritised correctness over speed.
|
||||
While we have a pretty good [test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) for validating our changes, the changes came down to [284 changed files with 32,316 additions and 24,245 deletions](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...ed78ecd) and bugs are expected. We need help testing this release. In addition, while we think the performance should in general be better, there might be regressions in parts of the platform, particularly where we prioritised correctness over speed.
|
||||
|
||||
There are also several bugfixes that has been encountered and fixed as part of implementing these changes, particularly
|
||||
after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460).
|
||||
|
@ -72,6 +86,9 @@ after improving the test harness as part of adopting [#1460](https://github.com/
|
|||
- Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792)
|
||||
- Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764)
|
||||
- Make sure integration tests cover postgres for all scenarios
|
||||
- CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109)
|
||||
- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109)
|
||||
- Fix issue where shutting down headscale would hang [#2113](https://github.com/juanfont/headscale/pull/2113)
|
||||
|
||||
## 0.22.3 (2023-05-12)
|
||||
|
||||
|
|
18
README.md
18
README.md
|
@ -62,15 +62,15 @@ buttons available in the repo.
|
|||
|
||||
## Client OS support
|
||||
|
||||
| OS | Supports headscale |
|
||||
| ------- | --------------------------------------------------------- |
|
||||
| Linux | Yes |
|
||||
| OpenBSD | Yes |
|
||||
| FreeBSD | Yes |
|
||||
| macOS | Yes (see `/apple` on your headscale for more information) |
|
||||
| Windows | Yes [docs](./docs/windows-client.md) |
|
||||
| Android | Yes [docs](./docs/android-client.md) |
|
||||
| iOS | Yes [docs](./docs/iOS-client.md) |
|
||||
| OS | Supports headscale |
|
||||
| ------- | -------------------------------------------------------------------------------------------------- |
|
||||
| Linux | Yes |
|
||||
| OpenBSD | Yes |
|
||||
| FreeBSD | Yes |
|
||||
| Windows | Yes (see [docs](./docs/windows-client.md) and `/windows` on your headscale for more information) |
|
||||
| Android | Yes (see [docs](./docs/android-client.md)) |
|
||||
| macOS | Yes (see [docs](./docs/apple-client.md#macos) and `/apple` on your headscale for more information) |
|
||||
| iOS | Yes (see [docs](./docs/apple-client.md#ios) and `/apple` on your headscale for more information) |
|
||||
|
||||
## Running headscale
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ var listAPIKeys = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -67,14 +67,10 @@ var listAPIKeys = &cobra.Command{
|
|||
fmt.Sprintf("Error getting the list of keys: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetApiKeys(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{
|
||||
|
@ -102,8 +98,6 @@ var listAPIKeys = &cobra.Command{
|
|||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -119,9 +113,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
log.Trace().
|
||||
Msg("Preparing to create ApiKey")
|
||||
|
||||
request := &v1.CreateApiKeyRequest{}
|
||||
|
||||
durationStr, _ := cmd.Flags().GetString("expiration")
|
||||
|
@ -133,19 +124,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
|
|||
fmt.Sprintf("Could not parse duration: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiration := time.Now().UTC().Add(time.Duration(duration))
|
||||
|
||||
log.Trace().
|
||||
Dur("expiration", time.Duration(duration)).
|
||||
Msg("expiration has been set")
|
||||
|
||||
request.Expiration = timestamppb.New(expiration)
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -156,8 +141,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
|
|||
fmt.Sprintf("Cannot create Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
|
||||
|
@ -178,11 +161,9 @@ var expireAPIKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -197,8 +178,6 @@ var expireAPIKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
|
@ -219,11 +198,9 @@ var deleteAPIKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -238,8 +215,6 @@ var deleteAPIKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key deleted", output)
|
||||
|
|
|
@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{
|
|||
Short: "Test the configuration.",
|
||||
Long: "Run a test of the configuration and exit.",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
_, err := getHeadscaleApp()
|
||||
_, err := newHeadscaleServerWithConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Caller().Err(err).Msg("Error initializing")
|
||||
}
|
||||
|
|
|
@ -64,11 +64,9 @@ var createNodeCmd = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -79,8 +77,6 @@ var createNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting node from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
machineKey, err := cmd.Flags().GetString("key")
|
||||
|
@ -90,8 +86,6 @@ var createNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var mkey key.MachinePublic
|
||||
|
@ -102,8 +96,6 @@ var createNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := cmd.Flags().GetStringSlice("route")
|
||||
|
@ -113,8 +105,6 @@ var createNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting routes from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
request := &v1.DebugCreateNodeRequest{
|
||||
|
@ -131,8 +121,6 @@ var createNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetNode(), "Node created", output)
|
||||
|
|
|
@ -116,11 +116,9 @@ var registerNodeCmd = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -131,8 +129,6 @@ var registerNodeCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting node key from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
request := &v1.RegisterNodeRequest{
|
||||
|
@ -150,8 +146,6 @@ var registerNodeCmd = &cobra.Command{
|
|||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(
|
||||
|
@ -169,17 +163,13 @@ var listNodesCmd = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
showTags, err := cmd.Flags().GetBool("tags")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -194,21 +184,15 @@ var listNodesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetNodes(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
|
@ -218,8 +202,6 @@ var listNodesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -243,7 +225,7 @@ var expireNodeCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -286,7 +268,7 @@ var renameNodeCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -335,7 +317,7 @@ var deleteNodeCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -435,7 +417,7 @@ var moveNodeCmd = &cobra.Command{
|
|||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -508,7 +490,7 @@ be assigned to nodes.`,
|
|||
return
|
||||
}
|
||||
if confirm {
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -681,7 +663,7 @@ var tagCmd = &cobra.Command{
|
|||
Aliases: []string{"tags", "t"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
|
@ -30,7 +31,8 @@ var getPolicy = &cobra.Command{
|
|||
Short: "Print the current ACL Policy",
|
||||
Aliases: []string{"show", "view", "fetch"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -38,13 +40,13 @@ var getPolicy = &cobra.Command{
|
|||
|
||||
response, err := client.GetPolicy(ctx, request)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to get the policy")
|
||||
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
|
||||
}
|
||||
|
||||
// TODO(pallabpain): Maybe print this better?
|
||||
SuccessOutput("", response.GetPolicy(), "hujson")
|
||||
// This does not pass output as we dont support yaml, json or json-line
|
||||
// output for this command. It is HuJSON already.
|
||||
SuccessOutput("", response.GetPolicy(), "")
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -56,33 +58,28 @@ var setPolicy = &cobra.Command{
|
|||
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
|
||||
Aliases: []string{"put", "update"},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
policyPath, _ := cmd.Flags().GetString("file")
|
||||
|
||||
f, err := os.Open(policyPath)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Error opening the policy file")
|
||||
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Error reading the policy file")
|
||||
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
|
||||
}
|
||||
|
||||
request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
if _, err := client.SetPolicy(ctx, request); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to set ACL Policy")
|
||||
|
||||
return
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
||||
}
|
||||
|
||||
SuccessOutput(nil, "Policy updated.", "")
|
||||
|
|
|
@ -60,11 +60,9 @@ var listPreAuthKeys = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -85,8 +83,6 @@ var listPreAuthKeys = &cobra.Command{
|
|||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetPreAuthKeys(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{
|
||||
|
@ -134,8 +130,6 @@ var listPreAuthKeys = &cobra.Command{
|
|||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -150,20 +144,12 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
reusable, _ := cmd.Flags().GetBool("reusable")
|
||||
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
|
||||
tags, _ := cmd.Flags().GetStringSlice("tags")
|
||||
|
||||
log.Trace().
|
||||
Bool("reusable", reusable).
|
||||
Bool("ephemeral", ephemeral).
|
||||
Str("user", user).
|
||||
Msg("Preparing to create preauthkey")
|
||||
|
||||
request := &v1.CreatePreAuthKeyRequest{
|
||||
User: user,
|
||||
Reusable: reusable,
|
||||
|
@ -180,8 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Could not parse duration: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiration := time.Now().UTC().Add(time.Duration(duration))
|
||||
|
@ -192,7 +176,7 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
|
||||
request.Expiration = timestamppb.New(expiration)
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -203,8 +187,6 @@ var createPreAuthKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
|
||||
|
@ -227,11 +209,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
user, err := cmd.Flags().GetString("user")
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -247,8 +227,6 @@ var expirePreAuthKeyCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response, "Key expired", output)
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/tcnksm/go-latest"
|
||||
)
|
||||
|
||||
|
@ -49,11 +50,6 @@ func initConfig() {
|
|||
}
|
||||
}
|
||||
|
||||
cfg, err := types.GetHeadscaleConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to read headscale configuration")
|
||||
}
|
||||
|
||||
machineOutput := HasMachineOutputFlag()
|
||||
|
||||
// If the user has requested a "node" readable format,
|
||||
|
@ -62,11 +58,13 @@ func initConfig() {
|
|||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
}
|
||||
|
||||
if cfg.Log.Format == types.JSONLogFormat {
|
||||
log.Logger = log.Output(os.Stdout)
|
||||
}
|
||||
// logFormat := viper.GetString("log.format")
|
||||
// if logFormat == types.JSONLogFormat {
|
||||
// log.Logger = log.Output(os.Stdout)
|
||||
// }
|
||||
|
||||
if !cfg.DisableUpdateCheck && !machineOutput {
|
||||
disableUpdateCheck := viper.GetBool("disable_check_updates")
|
||||
if !disableUpdateCheck && !machineOutput {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||
Version != "dev" {
|
||||
githubTag := &latest.GithubTag{
|
||||
|
|
|
@ -64,11 +64,9 @@ var listRoutesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -82,14 +80,10 @@ var listRoutesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetRoutes(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routes = response.GetRoutes()
|
||||
|
@ -103,14 +97,10 @@ var listRoutesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetRoutes(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
routes = response.GetRoutes()
|
||||
|
@ -119,8 +109,6 @@ var listRoutesCmd = &cobra.Command{
|
|||
tableData := routesToPtables(routes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
|
@ -130,8 +118,6 @@ var listRoutesCmd = &cobra.Command{
|
|||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -150,11 +136,9 @@ var enableRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -167,14 +151,10 @@ var enableRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -193,11 +173,9 @@ var disableRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -210,14 +188,10 @@ var disableRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -236,11 +210,9 @@ var deleteRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error getting machine id from flag: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -253,14 +225,10 @@ var deleteRouteCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response, "", output)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
@ -16,14 +19,14 @@ var serveCmd = &cobra.Command{
|
|||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
app, err := getHeadscaleApp()
|
||||
app, err := newHeadscaleServerWithConfig()
|
||||
if err != nil {
|
||||
log.Fatal().Caller().Err(err).Msg("Error initializing")
|
||||
}
|
||||
|
||||
err = app.Serve()
|
||||
if err != nil {
|
||||
log.Fatal().Caller().Err(err).Msg("Error starting server")
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
|
||||
}
|
||||
},
|
||||
}
|
|
@ -44,7 +44,7 @@ var createUserCmd = &cobra.Command{
|
|||
|
||||
userName := args[0]
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -63,8 +63,6 @@ var createUserCmd = &cobra.Command{
|
|||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User created", output)
|
||||
|
@ -91,7 +89,7 @@ var destroyUserCmd = &cobra.Command{
|
|||
Name: userName,
|
||||
}
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -102,8 +100,6 @@ var destroyUserCmd = &cobra.Command{
|
|||
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
confirm := false
|
||||
|
@ -134,8 +130,6 @@ var destroyUserCmd = &cobra.Command{
|
|||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
SuccessOutput(response, "User destroyed", output)
|
||||
} else {
|
||||
|
@ -151,7 +145,7 @@ var listUsersCmd = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -164,14 +158,10 @@ var listUsersCmd = &cobra.Command{
|
|||
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
SuccessOutput(response.GetUsers(), "", output)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tableData := pterm.TableData{{"ID", "Name", "Created"}}
|
||||
|
@ -192,8 +182,6 @@ var listUsersCmd = &cobra.Command{
|
|||
fmt.Sprintf("Failed to render pterm table: %s", err),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -213,7 +201,7 @@ var renameUserCmd = &cobra.Command{
|
|||
Run: func(cmd *cobra.Command, args []string) {
|
||||
output, _ := cmd.Flags().GetString("output")
|
||||
|
||||
ctx, client, conn, cancel := getHeadscaleCLIClient()
|
||||
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -232,8 +220,6 @@ var renameUserCmd = &cobra.Command{
|
|||
),
|
||||
output,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.GetUser(), "User renamed", output)
|
||||
|
|
|
@ -23,8 +23,8 @@ const (
|
|||
SocketWritePermissions = 0o666
|
||||
)
|
||||
|
||||
func getHeadscaleApp() (*hscontrol.Headscale, error) {
|
||||
cfg, err := types.GetHeadscaleConfig()
|
||||
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
|
||||
cfg, err := types.LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to load configuration while creating headscale instance: %w",
|
||||
|
@ -40,8 +40,8 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
|
|||
return app, nil
|
||||
}
|
||||
|
||||
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||
cfg, err := types.GetHeadscaleConfig()
|
||||
func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
|
||||
cfg, err := types.LoadCLIConfig()
|
||||
if err != nil {
|
||||
log.Fatal().
|
||||
Err(err).
|
||||
|
@ -130,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
|
|||
return ctx, client, conn, cancel
|
||||
}
|
||||
|
||||
func SuccessOutput(result interface{}, override string, outputFormat string) {
|
||||
func output(result interface{}, override string, outputFormat string) string {
|
||||
var jsonBytes []byte
|
||||
var err error
|
||||
switch outputFormat {
|
||||
|
@ -151,21 +151,26 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
|
|||
}
|
||||
default:
|
||||
// nolint
|
||||
fmt.Println(override)
|
||||
|
||||
return
|
||||
return override
|
||||
}
|
||||
|
||||
// nolint
|
||||
fmt.Println(string(jsonBytes))
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// SuccessOutput prints the result to stdout and exits with status code 0.
|
||||
func SuccessOutput(result interface{}, override string, outputFormat string) {
|
||||
fmt.Println(output(result, override, outputFormat))
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// ErrorOutput prints an error message to stderr and exits with status code 1.
|
||||
func ErrorOutput(errResult error, override string, outputFormat string) {
|
||||
type errOutput struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
SuccessOutput(errOutput{errResult.Error()}, override, outputFormat)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errResult.Error()}, override, outputFormat))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func HasMachineOutputFlag() bool {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
|
@ -113,60 +112,3 @@ func (*Suite) TestConfigLoading(c *check.C) {
|
|||
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
|
||||
c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false)
|
||||
}
|
||||
|
||||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
||||
// Populate a custom config file
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
err := os.WriteFile(configFile, configYaml, 0o600)
|
||||
if err != nil {
|
||||
c.Fatalf("Couldn't write file %s", configFile)
|
||||
}
|
||||
}
|
||||
|
||||
func (*Suite) TestTLSConfigValidation(c *check.C) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
// defer os.RemoveAll(tmpDir)
|
||||
configYaml := []byte(`---
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: ""
|
||||
tls_cert_path: abc.pem
|
||||
noise:
|
||||
private_key_path: noise_private.key`)
|
||||
writeConfig(c, tmpDir, configYaml)
|
||||
|
||||
// Check configuration validation errors (1)
|
||||
err = types.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.NotNil)
|
||||
// check.Matches can not handle multiline strings
|
||||
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
|
||||
c.Assert(
|
||||
tmp,
|
||||
check.Matches,
|
||||
".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*",
|
||||
)
|
||||
c.Assert(
|
||||
tmp,
|
||||
check.Matches,
|
||||
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
|
||||
)
|
||||
c.Assert(
|
||||
tmp,
|
||||
check.Matches,
|
||||
".*Fatal config error: server_url must start with https:// or http://.*",
|
||||
)
|
||||
|
||||
// Check configuration validation errors (2)
|
||||
configYaml = []byte(`---
|
||||
noise:
|
||||
private_key_path: noise_private.key
|
||||
server_url: http://127.0.0.1:8080
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||
`)
|
||||
writeConfig(c, tmpDir, configYaml)
|
||||
err = types.LoadConfig(tmpDir, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
|
|
@ -8,12 +8,9 @@ This documentation has the goal of showing how a user can use the official Andro
|
|||
|
||||
Install the official Tailscale Android client from the [Google Play Store](https://play.google.com/store/apps/details?id=com.tailscale.ipn) or [F-Droid](https://f-droid.org/packages/com.tailscale.ipn/).
|
||||
|
||||
Ensure that the installed version is at least 1.30.0, as that is the first release to support custom URLs.
|
||||
|
||||
## Configuring the headscale URL
|
||||
|
||||
After opening the app:
|
||||
|
||||
- Open setting and go into account settings
|
||||
- In the kebab menu icon (three dots) on the top bar on the right select “Use an alternate server”
|
||||
- Enter your server URL and follow the instructions
|
||||
- Open the app and select the settings menu in the upper-right corner
|
||||
- Tap on `Accounts`
|
||||
- In the kebab menu icon (three dots) in the upper-right corner select `Use an alternate server`
|
||||
- Enter your server URL (e.g `https://headscale.example.com`) and follow the instructions
|
||||
|
|
51
docs/apple-client.md
Normal file
51
docs/apple-client.md
Normal file
|
@ -0,0 +1,51 @@
|
|||
# Connecting an Apple client
|
||||
|
||||
## Goal
|
||||
|
||||
This documentation has the goal of showing how a user can use the official iOS and macOS [Tailscale](https://tailscale.com) clients with `headscale`.
|
||||
|
||||
!!! info "Instructions on your headscale instance"
|
||||
|
||||
An endpoint with information on how to connect your Apple device
|
||||
is also available at `/apple` on your running instance.
|
||||
|
||||
## iOS
|
||||
|
||||
### Installation
|
||||
|
||||
Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037).
|
||||
|
||||
### Configuring the headscale URL
|
||||
|
||||
- Open Tailscale and make sure you are _not_ logged in to any account
|
||||
- Open Settings on the iOS device
|
||||
- Scroll down to the `third party apps` section, under `Game Center` or `TV Provider`
|
||||
- Find Tailscale and select it
|
||||
- If the iOS device was previously logged into Tailscale, switch the `Reset Keychain` toggle to `on`
|
||||
- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) under `Alternate Coordination Server URL`
|
||||
- Restart the app by closing it from the iOS app switcher, open the app and select the regular sign in option
|
||||
_(non-SSO)_. It should open up to the headscale authentication page.
|
||||
- Enter your credentials and log in. Headscale should now be working on your iOS device.
|
||||
|
||||
## macOS
|
||||
|
||||
### Installation
|
||||
|
||||
Choose one of the available [Tailscale clients for macOS](https://tailscale.com/kb/1065/macos-variants) and install it.
|
||||
|
||||
### Configuring the headscale URL
|
||||
|
||||
#### Command line
|
||||
|
||||
Use Tailscale's login command to connect with your headscale instance (e.g `https://headscale.example.com`):
|
||||
|
||||
```
|
||||
tailscale login --login-server <YOUR_HEADSCALE_URL>
|
||||
```
|
||||
|
||||
#### GUI
|
||||
|
||||
- ALT + Click the Tailscale icon in the menu and hover over the Debug menu
|
||||
- Under `Custom Login Server`, select `Add Account...`
|
||||
- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) and press `Add Account`
|
||||
- Follow the login procedure in the browser
|
|
@ -5,7 +5,7 @@
|
|||
Register the node and make it advertise itself as an exit node:
|
||||
|
||||
```console
|
||||
$ sudo tailscale up --login-server https://my-server.com --advertise-exit-node
|
||||
$ sudo tailscale up --login-server https://headscale.example.com --advertise-exit-node
|
||||
```
|
||||
|
||||
If the node is already registered, it can advertise exit capabilities like this:
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
# Connecting an iOS client
|
||||
|
||||
## Goal
|
||||
|
||||
This documentation has the goal of showing how a user can use the official iOS [Tailscale](https://tailscale.com) client with `headscale`.
|
||||
|
||||
## Installation
|
||||
|
||||
Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037).
|
||||
|
||||
Ensure that the installed version is at least 1.38.1, as that is the first release to support alternate control servers.
|
||||
|
||||
## Configuring the headscale URL
|
||||
|
||||
!!! info "Apple devices"
|
||||
|
||||
An endpoint with information on how to connect your Apple devices
|
||||
(currently macOS only) is available at `/apple` on your running instance.
|
||||
|
||||
Ensure that the tailscale app is logged out before proceeding.
|
||||
|
||||
Go to iOS settings, scroll down past game center and tv provider to the tailscale app and select it. The headscale URL can be entered into the _"ALTERNATE COORDINATION SERVER URL"_ box.
|
||||
|
||||
> **Note**
|
||||
>
|
||||
> If the app was previously logged into tailscale, toggle on the _Reset Keychain_ switch.
|
||||
|
||||
Restart the app by closing it from the iOS app switcher, open the app and select the regular _Sign in_ option (non-SSO), and it should open up to the headscale authentication page.
|
||||
|
||||
Enter your credentials and log in. Headscale should now be working on your iOS device.
|
Binary file not shown.
Before Width: | Height: | Size: 101 KiB |
|
@ -4,39 +4,41 @@
|
|||
|
||||
This documentation has the goal of showing how a user can use the official Windows [Tailscale](https://tailscale.com) client with `headscale`.
|
||||
|
||||
## Add registry keys
|
||||
!!! info "Instructions on your headscale instance"
|
||||
|
||||
To make the Windows client behave as expected and to run well with `headscale`, two registry keys **must** be set:
|
||||
|
||||
- `HKLM:\SOFTWARE\Tailscale IPN\UnattendedMode` must be set to `always` as a `string` type, to allow Tailscale to run properly in the background
|
||||
- `HKLM:\SOFTWARE\Tailscale IPN\LoginURL` must be set to `<YOUR HEADSCALE URL>` as a `string` type, to ensure Tailscale contacts the correct control server.
|
||||
|
||||
You can set these using the Windows Registry Editor:
|
||||
|
||||
![windows-registry](./images/windows-registry.png)
|
||||
|
||||
Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"):
|
||||
|
||||
```
|
||||
New-Item -Path "HKLM:\SOFTWARE\Tailscale IPN"
|
||||
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
|
||||
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL
|
||||
```
|
||||
|
||||
The Tailscale Windows client has been observed to reset its configuration on logout/reboot and these two keys [resolves that issue](https://github.com/tailscale/tailscale/issues/2798).
|
||||
|
||||
For a guide on how to edit registry keys, [check out Computer Hope](https://www.computerhope.com/issues/ch001348.htm).
|
||||
An endpoint with information on how to connect your Windows device
|
||||
is also available at `/windows` on your running instance.
|
||||
|
||||
## Installation
|
||||
|
||||
Download the [Official Windows Client](https://tailscale.com/download/windows) and install it.
|
||||
|
||||
When the installation has finished, start Tailscale and log in (you might have to click the icon in the system tray).
|
||||
## Configuring the headscale URL
|
||||
|
||||
The log in should open a browser Window and direct you to your `headscale` instance.
|
||||
Open a Command Prompt or Powershell and use Tailscale's login command to connect with your headscale instance (e.g
|
||||
`https://headscale.example.com`):
|
||||
|
||||
```
|
||||
tailscale login --login-server <YOUR_HEADSCALE_URL>
|
||||
```
|
||||
|
||||
Follow the instructions in the opened browser window to finish the configuration.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Unattended mode
|
||||
|
||||
By default, Tailscale's Windows client is only running when the user is logged in. If you want to keep Tailscale running
|
||||
all the time, please enable "Unattended mode":
|
||||
|
||||
- Click on the Tailscale tray icon and select `Preferences`
|
||||
- Enable `Run unattended`
|
||||
- Confirm the "Unattended mode" message
|
||||
|
||||
See also [Keep Tailscale running when I'm not logged in to my computer](https://tailscale.com/kb/1088/run-unattended)
|
||||
|
||||
### Failing node registration
|
||||
|
||||
If you are seeing repeated messages like:
|
||||
|
||||
```
|
||||
|
@ -53,8 +55,7 @@ This typically means that the registry keys above was not set appropriately.
|
|||
|
||||
To reset and try again, it is important to do the following:
|
||||
|
||||
1. Ensure the registry keys from the previous guide is correctly set.
|
||||
2. Shut down the Tailscale service (or the client running in the tray)
|
||||
3. Delete Tailscale Application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` and try to connect again.
|
||||
4. Ensure the Windows node is deleted from headscale (to ensure fresh setup)
|
||||
5. Start Tailscale on the windows machine and retry the login.
|
||||
1. Shut down the Tailscale service (or the client running in the tray)
|
||||
2. Delete Tailscale Application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` and try to connect again.
|
||||
3. Ensure the Windows node is deleted from headscale (to ensure fresh setup)
|
||||
4. Start Tailscale on the Windows machine and retry the login.
|
||||
|
|
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1725099143,
|
||||
"narHash": "sha256-CHgumPZaC7z+WYx72WgaLt2XF0yUVzJS60rO4GZ7ytY=",
|
||||
"lastModified": 1726238386,
|
||||
"narHash": "sha256-3//V84fYaGVncFImitM6lSAliRdrGayZLdxWlpcuGk0=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5629520edecb69630a3f4d17d3d33fc96c13f6fe",
|
||||
"rev": "01f064c99c792715054dc7a70e4c1626dbbec0c3",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
10
flake.nix
10
flake.nix
|
@ -32,7 +32,7 @@
|
|||
|
||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||
# update this if you have a mismatch after doing a change to thos files.
|
||||
vendorHash = "sha256-+8dOxPG/Q+wuHgRwwWqdphHOuop0W9dVyClyQuh7aRc=";
|
||||
vendorHash = "sha256-jDTOlRzVorbVT8fXwRedU/O2nf15cKABFrdxoTA+1wk=";
|
||||
|
||||
subPackages = ["cmd/headscale"];
|
||||
|
||||
|
@ -57,9 +57,11 @@
|
|||
subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"];
|
||||
};
|
||||
|
||||
golangci-lint = prev.golangci-lint.override {
|
||||
buildGoModule = buildGo;
|
||||
};
|
||||
# Upstream does not override buildGoModule properly,
|
||||
# importing a specific module, so comment out for now.
|
||||
# golangci-lint = prev.golangci-lint.override {
|
||||
# buildGoModule = buildGo;
|
||||
# };
|
||||
|
||||
goreleaser = prev.goreleaser.override {
|
||||
buildGoModule = buildGo;
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
|
@ -41,7 +40,6 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -95,11 +93,10 @@ type Headscale struct {
|
|||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
|
||||
registrationCache *cache.Cache
|
||||
|
||||
authProvider AuthProvider
|
||||
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
}
|
||||
|
||||
|
@ -154,16 +151,31 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
}
|
||||
})
|
||||
|
||||
var authProvider AuthProvider
|
||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = app.initOIDC()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
cfg.ServerURL,
|
||||
&cfg.OIDC,
|
||||
app.db,
|
||||
app.registrationCache,
|
||||
app.nodeNotifier,
|
||||
app.ipAlloc,
|
||||
)
|
||||
if err != nil {
|
||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||
return nil, err
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication")
|
||||
}
|
||||
} else {
|
||||
authProvider = oidcProvider
|
||||
}
|
||||
}
|
||||
app.authProvider = authProvider
|
||||
|
||||
if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
// TODO(kradalby): revisit why this takes a list.
|
||||
|
@ -429,16 +441,15 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
|
||||
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet)
|
||||
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
|
||||
|
||||
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
|
||||
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
|
||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||
router.HandleFunc("/oidc/callback", provider.OIDCCallback).Methods(http.MethodGet)
|
||||
}
|
||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||
Methods(http.MethodGet)
|
||||
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
|
||||
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).
|
||||
Methods(http.MethodGet)
|
||||
|
||||
// TODO(kristoffer): move swagger into a package
|
||||
router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet)
|
||||
|
@ -772,7 +783,7 @@ func (h *Headscale) Serve() error {
|
|||
})
|
||||
}
|
||||
default:
|
||||
trace := log.Trace().Msgf
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
log.Info().
|
||||
Str("signal", sig.String()).
|
||||
Msg("Received signal to stop, shutting down gracefully")
|
||||
|
@ -780,55 +791,55 @@ func (h *Headscale) Serve() error {
|
|||
expireNodeCancel()
|
||||
h.ephemeralGC.Close()
|
||||
|
||||
trace("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
|
||||
// Gracefully shut down servers
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
types.HTTPShutdownTimeout,
|
||||
)
|
||||
trace("shutting down debug http server")
|
||||
info("shutting down debug http server")
|
||||
if err := debugHTTPServer.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
|
||||
log.Error().Err(err).Msg("failed to shutdown prometheus http")
|
||||
}
|
||||
trace("shutting down main http server")
|
||||
info("shutting down main http server")
|
||||
if err := httpServer.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown http")
|
||||
log.Error().Err(err).Msg("failed to shutdown http")
|
||||
}
|
||||
|
||||
trace("shutting down grpc server (socket)")
|
||||
info("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
info("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
|
||||
info("shutting down grpc server (socket)")
|
||||
grpcSocket.GracefulStop()
|
||||
|
||||
if grpcServer != nil {
|
||||
trace("shutting down grpc server (external)")
|
||||
info("shutting down grpc server (external)")
|
||||
grpcServer.GracefulStop()
|
||||
grpcListener.Close()
|
||||
}
|
||||
|
||||
if tailsqlContext != nil {
|
||||
trace("shutting down tailsql")
|
||||
info("shutting down tailsql")
|
||||
tailsqlContext.Done()
|
||||
}
|
||||
|
||||
trace("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
// Close network listeners
|
||||
trace("closing network listeners")
|
||||
info("closing network listeners")
|
||||
debugHTTPListener.Close()
|
||||
httpListener.Close()
|
||||
grpcGatewayConn.Close()
|
||||
|
||||
// Stop listening (and unlink the socket if unix type):
|
||||
trace("closing socket listener")
|
||||
info("closing socket listener")
|
||||
socketListener.Close()
|
||||
|
||||
// Close db connections
|
||||
trace("closing database connection")
|
||||
info("closing database connection")
|
||||
err = h.db.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to close db")
|
||||
log.Error().Err(err).Msg("failed to close db")
|
||||
}
|
||||
|
||||
log.Info().
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
|
@ -19,6 +18,11 @@ import (
|
|||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
type AuthProvider interface {
|
||||
RegisterHandler(http.ResponseWriter, *http.Request)
|
||||
AuthURL(key.MachinePublic) string
|
||||
}
|
||||
|
||||
func logAuthFunc(
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
|
@ -66,7 +70,7 @@ func (h *Headscale) handleRegister(
|
|||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey)
|
||||
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
|
||||
now := time.Now().UTC()
|
||||
logTrace("handleRegister called, looking up machine in DB")
|
||||
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
|
||||
|
@ -105,16 +109,6 @@ func (h *Headscale) handleRegister(
|
|||
|
||||
logInfo("Node not found in database, creating new")
|
||||
|
||||
givenName, err := h.db.GenerateGivenName(
|
||||
machineKey,
|
||||
regReq.Hostinfo.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
logErr(err, "Failed to generate given name for node")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// The node did not have a key to authenticate, which means
|
||||
// that we rely on a method that calls back some how (OpenID or CLI)
|
||||
// We create the node and then keep it around until a callback
|
||||
|
@ -122,7 +116,6 @@ func (h *Headscale) handleRegister(
|
|||
newNode := types.Node{
|
||||
MachineKey: machineKey,
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
NodeKey: regReq.NodeKey,
|
||||
LastSeen: &now,
|
||||
Expiry: &time.Time{},
|
||||
|
@ -175,7 +168,7 @@ func (h *Headscale) handleRegister(
|
|||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !regReq.Expiry.IsZero() &&
|
||||
regReq.Expiry.UTC().Before(now) {
|
||||
h.handleNodeLogOut(writer, *node, machineKey)
|
||||
h.handleNodeLogOut(writer, *node)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -183,7 +176,7 @@ func (h *Headscale) handleRegister(
|
|||
// If node is not expired, and it is register, we have a already accepted this node,
|
||||
// let it proceed with a valid registration
|
||||
if !node.IsExpired() {
|
||||
h.handleNodeWithValidRegistration(writer, *node, machineKey)
|
||||
h.handleNodeWithValidRegistration(writer, *node)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -196,7 +189,6 @@ func (h *Headscale) handleRegister(
|
|||
writer,
|
||||
regReq,
|
||||
*node,
|
||||
machineKey,
|
||||
)
|
||||
|
||||
return
|
||||
|
@ -209,7 +201,6 @@ func (h *Headscale) handleRegister(
|
|||
writer,
|
||||
regReq,
|
||||
*node,
|
||||
machineKey,
|
||||
)
|
||||
|
||||
return
|
||||
|
@ -354,21 +345,8 @@ func (h *Headscale) handleAuthKey(
|
|||
} else {
|
||||
now := time.Now().UTC()
|
||||
|
||||
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RegistrationHandler").
|
||||
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to generate given name for node")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
UserID: pak.User.ID,
|
||||
User: pak.User,
|
||||
MachineKey: machineKey,
|
||||
|
@ -410,7 +388,7 @@ func (h *Headscale) handleAuthKey(
|
|||
}
|
||||
}
|
||||
|
||||
h.db.Write(func(tx *gorm.DB) error {
|
||||
err = h.db.Write(func(tx *gorm.DB) error {
|
||||
return db.UsePreAuthKey(tx, pak)
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -471,17 +449,7 @@ func (h *Headscale) handleNewNode(
|
|||
// The node registration is new, redirect the client to the registration URL
|
||||
logTrace("The node seems to be new, sending auth url")
|
||||
|
||||
if h.oauth2Config != nil {
|
||||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.String(),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.String())
|
||||
}
|
||||
resp.AuthURL = h.authProvider.AuthURL(machineKey)
|
||||
|
||||
respBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
|
@ -504,7 +472,6 @@ func (h *Headscale) handleNewNode(
|
|||
func (h *Headscale) handleNodeLogOut(
|
||||
writer http.ResponseWriter,
|
||||
node types.Node,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
|
@ -587,7 +554,6 @@ func (h *Headscale) handleNodeLogOut(
|
|||
func (h *Headscale) handleNodeWithValidRegistration(
|
||||
writer http.ResponseWriter,
|
||||
node types.Node,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
|
@ -633,7 +599,6 @@ func (h *Headscale) handleNodeKeyRefresh(
|
|||
writer http.ResponseWriter,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
node types.Node,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
|
@ -709,15 +674,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
|
|||
Str("node_key_old", regReq.OldNodeKey.ShortString()).
|
||||
Msg("Node registration has expired or logged out. Sending a auth url to register")
|
||||
|
||||
if h.oauth2Config != nil {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.String())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.String())
|
||||
}
|
||||
resp.AuthURL = h.authProvider.AuthURL(machineKey)
|
||||
|
||||
respBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
|
|
|
@ -256,9 +256,6 @@ func NewHeadscaleDatabase(
|
|||
|
||||
for item, node := range nodes {
|
||||
if node.GivenName == "" {
|
||||
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
||||
node.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -268,7 +265,7 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
err = tx.Model(nodes[item]).Updates(types.Node{
|
||||
GivenName: normalizedHostname,
|
||||
GivenName: node.Hostname,
|
||||
}).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
@ -413,6 +410,18 @@ func NewHeadscaleDatabase(
|
|||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
{
|
||||
ID: "202407191627",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
err := tx.AutoMigrate(&types.User{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Rollback: func(db *gorm.DB) error { return nil },
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
|||
})
|
||||
}
|
||||
|
||||
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Preload("Routes").
|
||||
Where("given_name = ?", givenName).Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return getNode(rx, user, name)
|
||||
|
@ -242,9 +228,9 @@ func SetTags(
|
|||
}
|
||||
|
||||
// RenameNode takes a Node struct and a new GivenName for the nodes
|
||||
// and renames it.
|
||||
// and renames it. If the name is not unique, it will return an error.
|
||||
func RenameNode(tx *gorm.DB,
|
||||
nodeID uint64, newName string,
|
||||
nodeID types.NodeID, newName string,
|
||||
) error {
|
||||
err := util.CheckForFQDNRules(
|
||||
newName,
|
||||
|
@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB,
|
|||
return fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
uniq, err := isUnqiueName(tx, newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if name is unique: %w", err)
|
||||
}
|
||||
|
||||
if !uniq {
|
||||
return fmt.Errorf("name is not unique: %s", newName)
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename node in the database: %w", err)
|
||||
}
|
||||
|
@ -337,7 +332,7 @@ func RegisterNodeFromAuthCallback(
|
|||
|
||||
if nodeInterface, ok := cache.Get(mkey.String()); ok {
|
||||
if registrationNode, ok := nodeInterface.(types.Node); ok {
|
||||
user, err := GetUser(tx, userName)
|
||||
user, err := GetUserByName(tx, userName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register node from auth callback, %w",
|
||||
|
@ -390,7 +385,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Name).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Registering node")
|
||||
|
||||
// If the node exists and it already has IP(s), we just save it
|
||||
|
@ -406,7 +401,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||
Str("node", node.Hostname).
|
||||
Str("machine_key", node.MachineKey.ShortString()).
|
||||
Str("node_key", node.NodeKey.ShortString()).
|
||||
Str("user", node.User.Name).
|
||||
Str("user", node.User.Username()).
|
||||
Msg("Node authorized again")
|
||||
|
||||
return &node, nil
|
||||
|
@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
|||
node.IPv4 = ipv4
|
||||
node.IPv6 = ipv6
|
||||
|
||||
if node.GivenName == "" {
|
||||
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
|
||||
}
|
||||
|
||||
node.GivenName = givenName
|
||||
}
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
|
||||
}
|
||||
|
@ -617,18 +621,15 @@ func enableRoutes(tx *gorm.DB,
|
|||
}
|
||||
|
||||
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
||||
suppliedName,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
if len(suppliedName) > util.LabelHostnameLength {
|
||||
return "", types.ErrHostnameTooLong
|
||||
}
|
||||
|
||||
if randomSuffix {
|
||||
// Trim if a hostname will be longer than 63 chars after adding the hash.
|
||||
trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize
|
||||
if len(normalizedHostname) > trimmedHostnameLength {
|
||||
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
||||
if len(suppliedName) > trimmedHostnameLength {
|
||||
suppliedName = suppliedName[:trimmedHostnameLength]
|
||||
}
|
||||
|
||||
suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength)
|
||||
|
@ -636,46 +637,38 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
normalizedHostname += "-" + suffix
|
||||
suppliedName += "-" + suffix
|
||||
}
|
||||
|
||||
return normalizedHostname, nil
|
||||
return suppliedName, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GenerateGivenName(
|
||||
mkey key.MachinePublic,
|
||||
suppliedName string,
|
||||
) (string, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) {
|
||||
return GenerateGivenName(rx, mkey, suppliedName)
|
||||
})
|
||||
func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return len(nodes) == 0, nil
|
||||
}
|
||||
|
||||
func GenerateGivenName(
|
||||
func ensureUniqueGivenName(
|
||||
tx *gorm.DB,
|
||||
mkey key.MachinePublic,
|
||||
suppliedName string,
|
||||
name string,
|
||||
) (string, error) {
|
||||
givenName, err := generateGivenName(suppliedName, false)
|
||||
givenName, err := generateGivenName(name, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||
nodes, err := listNodesByGivenName(tx, givenName)
|
||||
unique, err := isUnqiueName(tx, givenName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var nodeFound *types.Node
|
||||
for idx, node := range nodes {
|
||||
if node.GivenName == givenName {
|
||||
nodeFound = nodes[idx]
|
||||
}
|
||||
}
|
||||
|
||||
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
|
||||
postfixedName, err := generateGivenName(suppliedName, true)
|
||||
if !unique {
|
||||
postfixedName, err := generateGivenName(name, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
|
@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
|||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
user1, err := db.CreateUser("user-1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode("user-1", "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machineKey2 := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "hostname-1",
|
||||
GivenName: "hostname-1",
|
||||
UserID: user1.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
|
||||
trx := db.DB.Save(node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
|
||||
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
|
||||
comment = check.Commentf("Same user, same node, same hostname, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
|
||||
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
|
||||
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
|
||||
}
|
||||
|
||||
func TestRenameNode(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
if err != nil {
|
||||
t.Fatalf("creating db: %s", err)
|
||||
}
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
user2, err := db.CreateUser("test2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
|
||||
node2 := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test",
|
||||
UserID: user2.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
|
||||
err = db.DB.Save(&node).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = db.DB.Save(&node2).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, node, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNode(tx, node2, nil, nil)
|
||||
return err
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := db.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, nodes, 2)
|
||||
|
||||
t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName)
|
||||
t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName)
|
||||
|
||||
assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName)
|
||||
assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName)
|
||||
assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname)
|
||||
assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName)
|
||||
assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname)
|
||||
assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname)
|
||||
assert.Len(t, nodes[0].Hostname, 4)
|
||||
assert.Len(t, nodes[1].Hostname, 4)
|
||||
assert.Len(t, nodes[0].GivenName, 4)
|
||||
assert.Len(t, nodes[1].GivenName, 13)
|
||||
|
||||
// Nodes can be renamed to a unique name
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "newname")
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err = db.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, nodes[0].Hostname, "test")
|
||||
assert.Equal(t, nodes[0].GivenName, "newname")
|
||||
|
||||
// Nodes can reuse name that is no longer used
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[1].ID, "test")
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err = db.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, nodes[0].Hostname, "test")
|
||||
assert.Equal(t, nodes[0].GivenName, "newname")
|
||||
assert.Equal(t, nodes[1].GivenName, "test")
|
||||
|
||||
// Nodes cannot be renamed to used names
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "test")
|
||||
})
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ var (
|
|||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
// TODO(kradalby): Should be ID, not name
|
||||
userName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
|
@ -36,13 +37,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
|||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func CreatePreAuthKey(
|
||||
tx *gorm.DB,
|
||||
// TODO(kradalby): Should be ID, not name
|
||||
userName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
user, err := GetUserByName(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -104,7 +106,7 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er
|
|||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := GetUser(tx, userName)
|
||||
user, err := GetUserByName(tx, userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -644,7 +644,7 @@ func EnableAutoApprovedRoutes(
|
|||
Msg("looking up route for autoapproving")
|
||||
|
||||
for _, approvedAlias := range routeApprovers {
|
||||
if approvedAlias == node.User.Name {
|
||||
if approvedAlias == node.User.Username() {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||
|
|
|
@ -49,7 +49,7 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
|
|||
// DestroyUser destroys a User. Returns error if the User does
|
||||
// not exist or if there are nodes associated with it.
|
||||
func DestroyUser(tx *gorm.DB, name string) error {
|
||||
user, err := GetUser(tx, name)
|
||||
user, err := GetUserByName(tx, name)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
|||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||
var err error
|
||||
oldUser, err := GetUser(tx, oldName)
|
||||
oldUser, err := GetUserByName(tx, oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -98,7 +98,7 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = GetUser(tx, newName)
|
||||
_, err = GetUserByName(tx, newName)
|
||||
if err == nil {
|
||||
return ErrUserExists
|
||||
}
|
||||
|
@ -115,13 +115,13 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUser(rx, name)
|
||||
return GetUserByName(rx, name)
|
||||
})
|
||||
}
|
||||
|
||||
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||
func GetUserByName(tx *gorm.DB, name string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
|
@ -133,6 +133,24 @@ func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
|||
return &user, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByOIDCIdentifier(rx, id)
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := tx.First(&user, "provider_identifier = ?", id); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx)
|
||||
|
@ -155,7 +173,7 @@ func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := GetUser(tx, name)
|
||||
user, err := GetUserByName(tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -180,7 +198,7 @@ func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := GetUser(tx, username)
|
||||
user, err := GetUserByName(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
|||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
_, err = db.GetUserByName("test")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
|
@ -73,10 +73,10 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
|||
err = db.RenameUser("test", "test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
_, err = db.GetUserByName("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
_, err = db.GetUser("test-renamed")
|
||||
_, err = db.GetUserByName("test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.RenameUser("test-does-not-exit", "test")
|
||||
|
|
|
@ -41,7 +41,7 @@ func (api headscaleV1APIServer) GetUser(
|
|||
ctx context.Context,
|
||||
request *v1.GetUserRequest,
|
||||
) (*v1.GetUserResponse, error) {
|
||||
user, err := api.h.db.GetUser(request.GetName())
|
||||
user, err := api.h.db.GetUserByName(request.GetName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (api headscaleV1APIServer) RenameUser(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
user, err := api.h.db.GetUser(request.GetNewName())
|
||||
user, err := api.h.db.GetUserByName(request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode(
|
|||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.RenameNode(
|
||||
tx,
|
||||
request.GetNodeId(),
|
||||
types.NodeID(request.GetNodeId()),
|
||||
request.GetNewName(),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy(
|
|||
case types.PolicyModeDB:
|
||||
p, err := api.h.db.GetPolicy()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("loading ACL from database: %w", err)
|
||||
}
|
||||
|
||||
return &v1.GetPolicyResponse{
|
||||
|
@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy(
|
|||
absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
|
||||
f, err := os.Open(absPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
b, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("reading policy from file: %w", err)
|
||||
}
|
||||
|
||||
return &v1.GetPolicyResponse{Policy: string(b)}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) SetPolicy(
|
||||
|
@ -774,7 +774,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||
ctx context.Context,
|
||||
request *v1.DebugCreateNodeRequest,
|
||||
) (*v1.DebugCreateNodeResponse, error) {
|
||||
user, err := api.h.db.GetUser(request.GetUser())
|
||||
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
newNode := types.Node{
|
||||
MachineKey: mkey,
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: request.GetName(),
|
||||
GivenName: givenName,
|
||||
User: *user,
|
||||
|
||||
Expiry: &time.Time{},
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"html/template"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
@ -167,12 +168,29 @@ var registerWebAPITemplate = template.Must(
|
|||
</html>
|
||||
`))
|
||||
|
||||
type AuthProviderWeb struct {
|
||||
serverURL string
|
||||
}
|
||||
|
||||
func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
|
||||
return &AuthProviderWeb{
|
||||
serverURL: serverURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
mKey.String())
|
||||
}
|
||||
|
||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register/:nkey.
|
||||
//
|
||||
// This is not part of the Tailscale control API, as we could send whatever URL
|
||||
// in the RegisterResponse.AuthURL field.
|
||||
func (h *Headscale) RegisterWebAPI(
|
||||
func (a *AuthProviderWeb) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
|
@ -187,7 +205,7 @@ func (h *Headscale) RegisterWebAPI(
|
|||
[]byte(machineKeyStr),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
|
||||
log.Warn().Err(err).Msg("Failed to parse incoming machinekey")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
|
@ -95,10 +94,10 @@ func generateUserProfiles(
|
|||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[string]types.User)
|
||||
userMap[node.User.Name] = node.User
|
||||
userMap := make(map[uint]types.User)
|
||||
userMap[node.User.ID] = node.User
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||
userMap[peer.User.ID] = peer.User // not worth checking if already is there
|
||||
}
|
||||
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
@ -122,32 +121,6 @@ func generateDNSConfig(
|
|||
|
||||
dnsConfig := cfg.DNSConfig.Clone()
|
||||
|
||||
// if MagicDNS is enabled
|
||||
if dnsConfig.Proxied {
|
||||
if cfg.DNSUserNameInMagicDNS {
|
||||
// Only inject the Search Domain of the current user
|
||||
// shared nodes should use their full FQDN
|
||||
dnsConfig.Domains = append(
|
||||
dnsConfig.Domains,
|
||||
fmt.Sprintf(
|
||||
"%s.%s",
|
||||
node.User.Name,
|
||||
baseDomain,
|
||||
),
|
||||
)
|
||||
|
||||
userSet := mapset.NewSet[types.User]()
|
||||
userSet.Add(node.User)
|
||||
for _, p := range peers {
|
||||
userSet.Add(p.User)
|
||||
}
|
||||
for _, user := range userSet.ToSlice() {
|
||||
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
|
||||
dnsConfig.Routes[dnsRoute] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
addNextDNSMetadata(dnsConfig.Resolvers, node)
|
||||
|
||||
return dnsConfig
|
||||
|
@ -227,7 +200,7 @@ func (m *Mapper) FullMapResponse(
|
|||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
// ReadOnlyResponse returns a MapResponse for the given node.
|
||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
||||
// Lite means that the peers has been omitted, this is intended
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
|
@ -552,7 +525,7 @@ func appendPeerChanges(
|
|||
}
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access eachother at all and remove them from the peers.
|
||||
// access each-other at all and remove them from the peers.
|
||||
if len(packetFilter) > 0 {
|
||||
changed = policy.FilterNodesByACL(node, changed, packetFilter)
|
||||
}
|
||||
|
@ -596,7 +569,7 @@ func appendPeerChanges(
|
|||
} else {
|
||||
// This is a hack to avoid sending an empty list of packet filters.
|
||||
// Since tailcfg.PacketFilter has omitempty, any empty PacketFilter will
|
||||
// be omitted, causing the client to consider it unchange, keeping the
|
||||
// be omitted, causing the client to consider it unchanged, keeping the
|
||||
// previous packet filter. Worst case, this can cause a node that previously
|
||||
// has access to a node to _not_ loose access if an empty (allow none) is sent.
|
||||
reduced := policy.ReduceFilterRules(node, packetFilter)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
|
@ -28,6 +29,9 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
|||
Hostname: hostname,
|
||||
UserID: userid,
|
||||
User: types.User{
|
||||
Model: gorm.Model{
|
||||
ID: userid,
|
||||
},
|
||||
Name: username,
|
||||
},
|
||||
}
|
||||
|
@ -72,14 +76,9 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||
{
|
||||
magicDNS: true,
|
||||
want: &tailcfg.DNSConfig{
|
||||
Routes: map[string][]*dnstype.Resolver{
|
||||
"shared1.foobar.headscale.net": {},
|
||||
"shared2.foobar.headscale.net": {},
|
||||
"shared3.foobar.headscale.net": {},
|
||||
},
|
||||
Routes: map[string][]*dnstype.Resolver{},
|
||||
Domains: []string{
|
||||
"foobar.headscale.net",
|
||||
"shared1.foobar.headscale.net",
|
||||
},
|
||||
Proxied: true,
|
||||
},
|
||||
|
@ -127,8 +126,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||
|
||||
got := generateDNSConfig(
|
||||
&types.Config{
|
||||
DNSConfig: &dnsConfigOrig,
|
||||
DNSUserNameInMagicDNS: true,
|
||||
DNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
baseDomain,
|
||||
nodeInShared1,
|
||||
|
|
|
@ -76,7 +76,7 @@ func tailNode(
|
|||
keyExpiry = time.Time{}
|
||||
}
|
||||
|
||||
hostname, err := node.GetFQDN(cfg, cfg.BaseDomain)
|
||||
hostname, err := node.GetFQDN(cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ type Notifier struct {
|
|||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
cfg *types.Config
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
|
@ -43,6 +44,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
|
|||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
cfg: cfg,
|
||||
closed: false,
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
|
@ -51,9 +53,19 @@ func NewNotifier(cfg *types.Config) *Notifier {
|
|||
return n
|
||||
}
|
||||
|
||||
// Close stops the batcher inside the notifier.
|
||||
// Close stops the batcher and closes all channels.
|
||||
func (n *Notifier) Close() {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
|
||||
|
||||
n.closed = true
|
||||
n.b.close()
|
||||
|
||||
for _, c := range n.nodes {
|
||||
close(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
||||
|
@ -70,6 +82,10 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
|||
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// If a channel exists, it means the node has opened a new
|
||||
// connection. Close the old channel and replace it.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
|
@ -96,6 +112,10 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) b
|
|||
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
|
||||
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(n.nodes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
@ -154,6 +174,10 @@ func (n *Notifier) NotifyWithIgnore(
|
|||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
@ -170,6 +194,10 @@ func (n *Notifier) NotifyByNodeID(
|
|||
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
|
||||
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if c, ok := n.nodes[nodeID]; ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -205,6 +233,10 @@ func (n *Notifier) sendAll(update types.StateUpdate) {
|
|||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for id, c := range n.nodes {
|
||||
// Whenever an update is sent to all nodes, there is a chance that the node
|
||||
// has disconnected and the goroutine that was supposed to consume the update
|
||||
|
|
|
@ -17,8 +17,10 @@ import (
|
|||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
|
@ -45,49 +47,77 @@ var (
|
|||
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username,omitempty"`
|
||||
type AuthProviderOIDC struct {
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
db *db.HSDatabase
|
||||
registrationCache *cache.Cache
|
||||
notifier *notifier.Notifier
|
||||
ipAlloc *db.IPAllocator
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
}
|
||||
|
||||
func (h *Headscale) initOIDC() error {
|
||||
func NewAuthProviderOIDC(
|
||||
ctx context.Context,
|
||||
serverURL string,
|
||||
cfg *types.OIDCConfig,
|
||||
db *db.HSDatabase,
|
||||
registrationCache *cache.Cache,
|
||||
notif *notifier.Notifier,
|
||||
ipAlloc *db.IPAllocator,
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
if h.oauth2Config == nil {
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||
}
|
||||
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDC.ClientID,
|
||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf(
|
||||
"%s/oidc/callback",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
),
|
||||
Scopes: h.cfg.OIDC.Scope,
|
||||
}
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf(
|
||||
"%s/oidc/callback",
|
||||
strings.TrimSuffix(serverURL, "/"),
|
||||
),
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
registrationCache: registrationCache,
|
||||
notifier: notif,
|
||||
ipAlloc: ipAlloc,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
|
||||
if h.cfg.OIDC.UseExpiryFromToken {
|
||||
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
mKey.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
|
||||
if a.cfg.UseExpiryFromToken {
|
||||
return idTokenExpiration
|
||||
}
|
||||
|
||||
return time.Now().Add(h.cfg.OIDC.Expiry)
|
||||
return time.Now().Add(a.cfg.Expiry)
|
||||
}
|
||||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey.
|
||||
func (h *Headscale) RegisterOIDC(
|
||||
// Listens in /register/:mKey.
|
||||
func (a *AuthProviderOIDC) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
|
@ -108,46 +138,33 @@ func (h *Headscale) RegisterOIDC(
|
|||
[]byte(machineKeyStr),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Msg("Failed to parse incoming nodekey in OIDC registration")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("Wrong params"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
if _, err := rand.Read(randomBlob); err != nil {
|
||||
util.LogErr(err, "could not read 16 bytes from rand")
|
||||
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(
|
||||
a.registrationCache.Set(
|
||||
stateStr,
|
||||
machineKey,
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
|
||||
|
||||
for k, v := range h.cfg.OIDC.ExtraParams {
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
||||
http.Redirect(writer, req, authURL, http.StatusFound)
|
||||
|
@ -170,79 +187,78 @@ var oidcCallbackTemplate = template.Must(
|
|||
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
|
||||
// TODO: Add groups information from OIDC tokens into node HostInfo
|
||||
// Listens in /oidc/callback.
|
||||
func (h *Headscale) OIDCCallback(
|
||||
func (a *AuthProviderOIDC) OIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
code, state, err := validateOIDCCallbackParams(writer, req)
|
||||
code, state, err := validateOIDCCallbackParams(req)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
|
||||
rawIDToken, err := a.getIDTokenForOIDCCallback(req.Context(), code)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
|
||||
idToken, err := a.verifyIDTokenForOIDCCallback(req.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
|
||||
idTokenExpiry := a.determineTokenExpiration(idToken.Expiry)
|
||||
|
||||
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
|
||||
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
|
||||
// if err != nil {
|
||||
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
|
||||
// return
|
||||
// }
|
||||
|
||||
claims, err := extractIDTokenClaims(writer, idToken)
|
||||
if err != nil {
|
||||
var claims types.OIDCClaims
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
|
||||
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
|
||||
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
|
||||
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
||||
machineKey, nodeExists, err := a.validateNodeForOIDCCallback(
|
||||
writer,
|
||||
state,
|
||||
claims,
|
||||
&claims,
|
||||
idTokenExpiry,
|
||||
)
|
||||
if err != nil || nodeExists {
|
||||
return
|
||||
}
|
||||
|
||||
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// register the node if it's new
|
||||
log.Debug().Msg("Registering new node after successful callback")
|
||||
|
||||
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
|
||||
user, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
|
||||
if err := a.registerNodeForOIDCCallback(user, machineKey, idTokenExpiry); err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := renderOIDCCallbackTemplate(writer, claims)
|
||||
content, err := renderOIDCCallbackTemplate(&claims)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -254,127 +270,57 @@ func (h *Headscale) OIDCCallback(
|
|||
}
|
||||
|
||||
func validateOIDCCallbackParams(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) (string, string, error) {
|
||||
code := req.URL.Query().Get("code")
|
||||
state := req.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("Wrong params"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return "", "", errEmptyOIDCCallbackParams
|
||||
}
|
||||
|
||||
return code, state, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getIDTokenForOIDCCallback(
|
||||
func (a *AuthProviderOIDC) getIDTokenForOIDCCallback(
|
||||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
code, state string,
|
||||
code string,
|
||||
) (string, error) {
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
||||
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
util.LogErr(err, "Could not exchange code for token")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, werr := writer.Write([]byte("Could not exchange code for token"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return "", err
|
||||
return "", fmt.Errorf("could not exchange code for token: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("code", code).
|
||||
Str("state", state).
|
||||
Msg("Got oidc callback")
|
||||
|
||||
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
|
||||
if !rawIDTokenOK {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("Could not extract ID Token"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return "", errNoOIDCIDToken
|
||||
}
|
||||
|
||||
return rawIDToken, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) verifyIDTokenForOIDCCallback(
|
||||
func (a *AuthProviderOIDC) verifyIDTokenForOIDCCallback(
|
||||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
rawIDToken string,
|
||||
) (*oidc.IDToken, error) {
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
util.LogErr(err, "failed to verify id token")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, werr := writer.Write([]byte("Failed to verify id token"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func extractIDTokenClaims(
|
||||
writer http.ResponseWriter,
|
||||
idToken *oidc.IDToken,
|
||||
) (*IDTokenClaims, error) {
|
||||
var claims IDTokenClaims
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
util.LogErr(err, "Failed to decode id token claims")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, werr := writer.Write([]byte("Failed to decode id token claims"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
|
||||
// that the authenticated principal ends with @<alloweddomain>.
|
||||
func validateOIDCAllowedDomains(
|
||||
writer http.ResponseWriter,
|
||||
allowedDomains []string,
|
||||
claims *IDTokenClaims,
|
||||
claims *types.OIDCClaims,
|
||||
) error {
|
||||
if len(allowedDomains) > 0 {
|
||||
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
|
||||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
|
||||
log.Trace().Msg("authenticated principal does not match any allowed domain")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return errOIDCAllowedDomains
|
||||
}
|
||||
}
|
||||
|
@ -387,9 +333,8 @@ func validateOIDCAllowedDomains(
|
|||
// claims.Groups can be populated by adding a client scope named
|
||||
// 'groups' that contains group membership.
|
||||
func validateOIDCAllowedGroups(
|
||||
writer http.ResponseWriter,
|
||||
allowedGroups []string,
|
||||
claims *IDTokenClaims,
|
||||
claims *types.OIDCClaims,
|
||||
) error {
|
||||
if len(allowedGroups) > 0 {
|
||||
for _, group := range allowedGroups {
|
||||
|
@ -398,14 +343,6 @@ func validateOIDCAllowedGroups(
|
|||
}
|
||||
}
|
||||
|
||||
log.Trace().Msg("authenticated principal not in any allowed groups")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return errOIDCAllowedGroups
|
||||
}
|
||||
|
||||
|
@ -415,20 +352,12 @@ func validateOIDCAllowedGroups(
|
|||
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
|
||||
// that the authenticated principal is part of that list.
|
||||
func validateOIDCAllowedUsers(
|
||||
writer http.ResponseWriter,
|
||||
allowedUsers []string,
|
||||
claims *IDTokenClaims,
|
||||
claims *types.OIDCClaims,
|
||||
) error {
|
||||
if len(allowedUsers) > 0 &&
|
||||
!slices.Contains(allowedUsers, claims.Email) {
|
||||
log.Trace().Msg("authenticated principal does not match any allowed user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return errOIDCAllowedUsers
|
||||
}
|
||||
|
||||
|
@ -439,40 +368,21 @@ func validateOIDCAllowedUsers(
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
func (h *Headscale) validateNodeForOIDCCallback(
|
||||
func (a *AuthProviderOIDC) validateNodeForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
state string,
|
||||
claims *IDTokenClaims,
|
||||
claims *types.OIDCClaims,
|
||||
expiry time.Time,
|
||||
) (*key.MachinePublic, bool, error) {
|
||||
// retrieve nodekey from state cache
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
machineKeyIf, machineKeyFound := a.registrationCache.Get(state)
|
||||
if !machineKeyFound {
|
||||
log.Trace().
|
||||
Msg("requested node state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state has expired"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
|
||||
if !machineKeyOK {
|
||||
log.Trace().
|
||||
Interface("got", machineKeyIf).
|
||||
Msg("requested node state key is not a nodekey")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCInvalidNodeState
|
||||
}
|
||||
|
||||
|
@ -480,7 +390,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := h.db.GetNodeByMachineKey(machineKey)
|
||||
node, _ := a.db.GetNodeByMachineKey(machineKey)
|
||||
|
||||
if node != nil {
|
||||
log.Trace().
|
||||
|
@ -488,20 +398,13 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
Str("node", node.Hostname).
|
||||
Msg("node already registered, reauthenticating")
|
||||
|
||||
err := h.db.NodeSetExpiry(node.ID, expiry)
|
||||
err := a.db.NodeSetExpiry(node.ID, expiry)
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to refresh node")
|
||||
http.Error(
|
||||
writer,
|
||||
"Failed to refresh node",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
|
||||
return nil, true, err
|
||||
}
|
||||
log.Debug().
|
||||
Str("node", node.Hostname).
|
||||
Str("expiresAt", fmt.Sprintf("%v", expiry)).
|
||||
Time("expiresAt", expiry).
|
||||
Msg("successfully refreshed node")
|
||||
|
||||
var content bytes.Buffer
|
||||
|
@ -509,13 +412,6 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
User: claims.Email,
|
||||
Verb: "Reauthenticated",
|
||||
}); err != nil {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, true, fmt.Errorf("rendering OIDC callback template: %w", err)
|
||||
}
|
||||
|
||||
|
@ -527,7 +423,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
h.nodeNotifier.NotifyByNodeID(
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.StateUpdate{
|
||||
Type: types.StateSelfUpdate,
|
||||
|
@ -537,7 +433,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
@ -545,79 +441,56 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
return &machineKey, false, nil
|
||||
}
|
||||
|
||||
func getUserName(
|
||||
writer http.ResponseWriter,
|
||||
claims *IDTokenClaims,
|
||||
stripEmaildomain bool,
|
||||
) (string, error) {
|
||||
userName, err := util.NormalizeToFQDNRules(
|
||||
claims.Email,
|
||||
stripEmaildomain,
|
||||
)
|
||||
if err != nil {
|
||||
util.LogErr(err, "couldn't normalize email")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("couldn't normalize email"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return "", err
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
claims *types.OIDCClaims,
|
||||
) (*types.User, error) {
|
||||
var user *types.User
|
||||
var err error
|
||||
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
return userName, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
userName string,
|
||||
) (*types.User, error) {
|
||||
user, err := h.db.GetUser(userName)
|
||||
if errors.Is(err, db.ErrUserNotFound) {
|
||||
user, err = h.db.CreateUser(userName)
|
||||
if err != nil {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("could not create user"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("creating new user: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("could not find or create user"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
||||
// look it up by username. This should only be needed once.
|
||||
if user == nil {
|
||||
user, err = a.db.GetUserByName(claims.Username)
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("find or create user: %w", err)
|
||||
// if the user is still not found, create a new empty user.
|
||||
if user == nil {
|
||||
user = &types.User{}
|
||||
}
|
||||
}
|
||||
|
||||
user.FromClaim(claims)
|
||||
err = a.db.DB.Save(user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) registerNodeForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
func (a *AuthProviderOIDC) registerNodeForOIDCCallback(
|
||||
user *types.User,
|
||||
machineKey *key.MachinePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
ipv4, ipv6, err := h.ipAlloc.Next()
|
||||
ipv4, ipv6, err := a.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||
if err := a.db.Write(func(tx *gorm.DB) error {
|
||||
if _, err := db.RegisterNodeFromAuthCallback(
|
||||
// TODO(kradalby): find a better way to use the cache across modules
|
||||
tx,
|
||||
h.registrationCache,
|
||||
a.registrationCache,
|
||||
*machineKey,
|
||||
// TODO(kradalby): Should be ID, not name
|
||||
user.Name,
|
||||
&expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
|
@ -628,36 +501,20 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
|||
|
||||
return nil
|
||||
}); err != nil {
|
||||
util.LogErr(err, "could not register node")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("could not register node"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return err
|
||||
return fmt.Errorf("could not register node: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func renderOIDCCallbackTemplate(
|
||||
writer http.ResponseWriter,
|
||||
claims *IDTokenClaims,
|
||||
claims *types.OIDCClaims,
|
||||
) (*bytes.Buffer, error) {
|
||||
var content bytes.Buffer
|
||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||
User: claims.Email,
|
||||
Verb: "Authenticated",
|
||||
}); err != nil {
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
|
||||
if werr != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -59,46 +59,6 @@ func (h *Headscale) WindowsConfigMessage(
|
|||
}
|
||||
}
|
||||
|
||||
// WindowsRegConfig generates and serves a .reg file configured with the Headscale server address.
|
||||
func (h *Headscale) WindowsRegConfig(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
config := WindowsRegistryConfig{
|
||||
URL: h.cfg.ServerURL,
|
||||
}
|
||||
|
||||
var content bytes.Buffer
|
||||
if err := windowsRegTemplate.Execute(&content, config); err != nil {
|
||||
log.Error().
|
||||
Str("handler", "WindowsRegConfig").
|
||||
Err(err).
|
||||
Msg("Could not render Apple macOS template")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, err := writer.Write([]byte("Could not render Windows registry template"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
_, err := writer.Write(content.Bytes())
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
}
|
||||
|
||||
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
|
||||
func (h *Headscale) AppleConfigMessage(
|
||||
writer http.ResponseWriter,
|
||||
|
@ -305,10 +265,6 @@ func (h *Headscale) ApplePlatformConfig(
|
|||
}
|
||||
}
|
||||
|
||||
type WindowsRegistryConfig struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
type AppleMobileConfig struct {
|
||||
UUID uuid.UUID
|
||||
URL string
|
||||
|
@ -320,14 +276,6 @@ type AppleMobilePlatformConfig struct {
|
|||
URL string
|
||||
}
|
||||
|
||||
var windowsRegTemplate = textTemplate.Must(
|
||||
textTemplate.New("windowsconfig").Parse(`Windows Registry Editor Version 5.00
|
||||
|
||||
[HKEY_LOCAL_MACHINE\SOFTWARE\Tailscale IPN]
|
||||
"UnattendedMode"="always"
|
||||
"LoginURL"="{{.URL}}"
|
||||
`))
|
||||
|
||||
var commonTemplate = textTemplate.Must(
|
||||
textTemplate.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
|
|
|
@ -737,15 +737,7 @@ func (pol *ACLPolicy) expandUsersFromGroup(
|
|||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf(
|
||||
"failed to normalize group %q, err: %w",
|
||||
group,
|
||||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
users = append(users, grp)
|
||||
users = append(users, group)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
|
@ -934,7 +926,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
}
|
||||
var found bool
|
||||
for _, owner := range owners {
|
||||
if node.User.Name == owner {
|
||||
if node.User.Username() == owner {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
@ -958,7 +950,7 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
|
||||
var out types.Nodes
|
||||
for _, node := range nodes {
|
||||
if node.User.Name == user {
|
||||
if node.User.Username() == user {
|
||||
out = append(out, node)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -341,7 +341,7 @@ func TestParsing(t *testing.T) {
|
|||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
|
@ -633,25 +633,6 @@ func Test_expandGroup(t *testing.T) {
|
|||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Expand emails in group strip domains",
|
||||
field: field{
|
||||
pol: ACLPolicy{
|
||||
Groups: Groups{
|
||||
"group:admin": []string{
|
||||
"joe.bar@gmail.com",
|
||||
"john.doe@yahoo.fr",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
group: "group:admin",
|
||||
stripEmail: true,
|
||||
},
|
||||
want: []string{"joe.bar", "john.doe"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Expand emails in group",
|
||||
field: field{
|
||||
|
@ -667,7 +648,7 @@ func Test_expandGroup(t *testing.T) {
|
|||
args: args{
|
||||
group: "group:admin",
|
||||
},
|
||||
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
|
||||
want: []string{"joe.bar@gmail.com", "john.doe@yahoo.fr"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -273,6 +274,12 @@ func (m *mapSession) serveLongPoll() {
|
|||
return
|
||||
}
|
||||
|
||||
// If the node has been removed from headscale, close the stream
|
||||
if slices.Contains(update.Removed, m.node.ID) {
|
||||
m.tracef("node removed, closing stream")
|
||||
return
|
||||
}
|
||||
|
||||
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
|
||||
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
|
||||
|
||||
|
|
|
@ -46,9 +46,7 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||
Path: tmpDir + "/headscale_test.db",
|
||||
},
|
||||
},
|
||||
OIDC: types.OIDCConfig{
|
||||
StripEmaildomain: false,
|
||||
},
|
||||
OIDC: types.OIDCConfig{},
|
||||
}
|
||||
|
||||
app, err = NewHeadscale(&cfg)
|
||||
|
|
|
@ -25,17 +25,48 @@
|
|||
</head>
|
||||
|
||||
<body>
|
||||
<h1>headscale: iOS configuration</h1>
|
||||
<h2>GUI</h2>
|
||||
<ol>
|
||||
<li>
|
||||
Install the official Tailscale iOS client from the
|
||||
<a href="https://apps.apple.com/app/tailscale/id1470499037"
|
||||
>App store</a
|
||||
>
|
||||
</li>
|
||||
<li>
|
||||
Open Tailscale and make sure you are <i>not</i> logged in to any account
|
||||
</li>
|
||||
<li>Open Settings on the iOS device</li>
|
||||
<li>
|
||||
Scroll down to the "third party apps" section, under "Game Center" or
|
||||
"TV Provider"
|
||||
</li>
|
||||
<li>
|
||||
Find Tailscale and select it
|
||||
<ul>
|
||||
<li>
|
||||
If the iOS device was previously logged into Tailscale, switch the
|
||||
"Reset Keychain" toggle to "on"
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>Enter "{{.URL}}" under "Alternate Coordination Server URL"</li>
|
||||
<li>
|
||||
Restart the app by closing it from the iOS app switcher, open the app
|
||||
and select the regular sign in option <i>(non-SSO)</i>. It should open
|
||||
up to the headscale authentication page.
|
||||
</li>
|
||||
<li>
|
||||
Enter your credentials and log in. Headscale should now be working on
|
||||
your iOS device
|
||||
</li>
|
||||
</ol>
|
||||
<h1>headscale: macOS configuration</h1>
|
||||
<h2>Recent Tailscale versions (1.34.0 and higher)</h2>
|
||||
<p>
|
||||
Tailscale added Fast User Switching in version 1.34 and you can now use
|
||||
the new login command to connect to one or more headscale (and Tailscale)
|
||||
servers. The previously used profiles does not have an effect anymore.
|
||||
</p>
|
||||
<h3>Command line</h3>
|
||||
<h2>Command line</h2>
|
||||
<p>Use Tailscale's login command to add your profile:</p>
|
||||
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
|
||||
<h3>GUI</h3>
|
||||
<h2>GUI</h2>
|
||||
<ol>
|
||||
<li>
|
||||
ALT + Click the Tailscale icon in the menu and hover over the Debug menu
|
||||
|
@ -46,44 +77,7 @@
|
|||
</li>
|
||||
<li>Follow the login procedure in the browser</li>
|
||||
</ol>
|
||||
<h2>Apple configuration profiles (1.32.0 and lower)</h2>
|
||||
<p>
|
||||
This page provides
|
||||
<a href="https://support.apple.com/guide/mdm/mdm-overview-mdmbf9e668/web"
|
||||
>configuration profiles</a
|
||||
>
|
||||
for the official Tailscale clients for
|
||||
</p>
|
||||
<ul>
|
||||
<li>
|
||||
<a href="https://apps.apple.com/app/tailscale/id1475387142"
|
||||
>macOS - AppStore Client</a
|
||||
>.
|
||||
</li>
|
||||
<li>
|
||||
<a href="https://pkgs.tailscale.com/stable/#macos"
|
||||
>macOS - Standalone Client</a
|
||||
>.
|
||||
</li>
|
||||
</ul>
|
||||
<p>
|
||||
The profiles will configure Tailscale.app to use <code>{{.URL}}</code> as
|
||||
its control server.
|
||||
</p>
|
||||
<h3>Caution</h3>
|
||||
<p>
|
||||
You should always download and inspect the profile before installing it:
|
||||
</p>
|
||||
<ul>
|
||||
<li>
|
||||
for app store client: <code>curl {{.URL}}/apple/macos-app-store</code>
|
||||
</li>
|
||||
<li>
|
||||
for standalone client: <code>curl {{.URL}}/apple/macos-standalone</code>
|
||||
</li>
|
||||
</ul>
|
||||
<h2>Profiles</h2>
|
||||
<h3>macOS</h3>
|
||||
<p>
|
||||
Headscale can be set to the default server by installing a Headscale
|
||||
configuration profile:
|
||||
|
@ -121,50 +115,17 @@
|
|||
</li>
|
||||
</ul>
|
||||
<p>Restart Tailscale.app and log in.</p>
|
||||
<h1>headscale: iOS configuration</h1>
|
||||
<h2>Recent Tailscale versions (1.38.1 and higher)</h2>
|
||||
<h3>Caution</h3>
|
||||
<p>
|
||||
Tailscale 1.38.1 on
|
||||
<a href="https://apps.apple.com/app/tailscale/id1470499037">iOS</a>
|
||||
added a configuration option to allow user to set an "Alternate
|
||||
Coordination server". This can be used to connect to your headscale
|
||||
server.
|
||||
You should always download and inspect the profile before installing it:
|
||||
</p>
|
||||
<h3>GUI</h3>
|
||||
<ol>
|
||||
<ul>
|
||||
<li>
|
||||
Install the official Tailscale iOS client from the
|
||||
<a href="https://apps.apple.com/app/tailscale/id1470499037"
|
||||
>App store</a
|
||||
>
|
||||
for app store client: <code>curl {{.URL}}/apple/macos-app-store</code>
|
||||
</li>
|
||||
<li>
|
||||
Open Tailscale and make sure you are <i>not</i> logged in to any account
|
||||
for standalone client: <code>curl {{.URL}}/apple/macos-standalone</code>
|
||||
</li>
|
||||
<li>Open Settings on the iOS device</li>
|
||||
<li>
|
||||
Scroll down to the "third party apps" section, under "Game Center" or
|
||||
"TV Provider"
|
||||
</li>
|
||||
<li>
|
||||
Find Tailscale and select it
|
||||
<ul>
|
||||
<li>
|
||||
If the iOS device was previously logged into Tailscale, switch the
|
||||
"Reset Keychain" toggle to "on"
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li>Enter "{{.URL}}" under "Alternate Coordination Server URL"</li>
|
||||
<li>
|
||||
Restart the app by closing it from the iOS app switcher, open the app
|
||||
and select the regular sign in option <i>(non-SSO)</i>. It should open
|
||||
up to the headscale authentication page.
|
||||
</li>
|
||||
<li>
|
||||
Enter your credentials and log in. Headscale should now be working on
|
||||
your iOS device
|
||||
</li>
|
||||
</ol>
|
||||
</ul>
|
||||
</body>
|
||||
</html>
|
||||
|
|
|
@ -25,75 +25,21 @@
|
|||
|
||||
<body>
|
||||
<h1>headscale: Windows configuration</h1>
|
||||
<h2>Recent Tailscale versions (1.34.0 and higher)</h2>
|
||||
<p>
|
||||
Tailscale added Fast User Switching in version 1.34 and you can now use
|
||||
the new login command to connect to one or more headscale (and Tailscale)
|
||||
servers. The previously used profiles does not have an effect anymore.
|
||||
</p>
|
||||
<p>Use Tailscale's login command to add your profile:</p>
|
||||
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
|
||||
|
||||
<h2>Windows registry configuration (1.32.0 and lower)</h2>
|
||||
<p>
|
||||
This page provides Windows registry information for the official Windows
|
||||
Tailscale client.
|
||||
</p>
|
||||
|
||||
<p></p>
|
||||
<p>
|
||||
The registry file will configure Tailscale to use <code>{{.URL}}</code> as
|
||||
its control server.
|
||||
</p>
|
||||
|
||||
<p></p>
|
||||
<h3>Caution</h3>
|
||||
<p>
|
||||
You should always download and inspect the registry file before installing
|
||||
it:
|
||||
</p>
|
||||
<pre><code>curl {{.URL}}/windows/tailscale.reg</code></pre>
|
||||
|
||||
<h2>Installation</h2>
|
||||
<p>
|
||||
Headscale can be set to the default server by running the registry file:
|
||||
</p>
|
||||
|
||||
<p>
|
||||
<a href="/windows/tailscale.reg" download="tailscale.reg"
|
||||
>Windows registry file</a
|
||||
Download
|
||||
<a
|
||||
href="https://tailscale.com/download/windows"
|
||||
rel="noreferrer noopener"
|
||||
target="_blank"
|
||||
>Tailscale for Windows</a
|
||||
>
|
||||
and install it.
|
||||
</p>
|
||||
|
||||
<ol>
|
||||
<li>Download the registry file, then run it</li>
|
||||
<li>Follow the prompts</li>
|
||||
<li>Install and run the official windows Tailscale client</li>
|
||||
<li>
|
||||
When the installation has finished, start Tailscale, and log in by
|
||||
clicking the icon in the system tray
|
||||
</li>
|
||||
</ol>
|
||||
<p>Or using REG:</p>
|
||||
<p>
|
||||
Open command prompt with Administrator rights. Issue the following
|
||||
commands to add the required registry entries:
|
||||
Open a Command Prompt or Powershell and use Tailscale's login command to
|
||||
connect with headscale:
|
||||
</p>
|
||||
<pre>
|
||||
<code>REG ADD "HKLM\Software\Tailscale IPN" /v UnattendedMode /t REG_SZ /d always
|
||||
REG ADD "HKLM\Software\Tailscale IPN" /v LoginURL /t REG_SZ /d "{{.URL}}"</code>
|
||||
</pre>
|
||||
<p>Or using Powershell</p>
|
||||
<p>
|
||||
Open Powershell with Administrator rights. Issue the following commands to
|
||||
add the required registry entries:
|
||||
</p>
|
||||
<pre>
|
||||
<code>New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
|
||||
New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value "{{.URL}}"</code>
|
||||
</pre>
|
||||
<p>Finally, restart Tailscale and log in.</p>
|
||||
|
||||
<p></p>
|
||||
<pre><code>tailscale login --login-server {{.URL}}</code></pre>
|
||||
</body>
|
||||
</html>
|
||||
|
|
|
@ -71,8 +71,7 @@ type Config struct {
|
|||
ACMEURL string
|
||||
ACMEEmail string
|
||||
|
||||
DNSConfig *tailcfg.DNSConfig
|
||||
DNSUserNameInMagicDNS bool
|
||||
DNSConfig *tailcfg.DNSConfig
|
||||
|
||||
UnixSocket string
|
||||
UnixSocketPermission fs.FileMode
|
||||
|
@ -90,12 +89,11 @@ type Config struct {
|
|||
}
|
||||
|
||||
type DNSConfig struct {
|
||||
MagicDNS bool `mapstructure:"magic_dns"`
|
||||
BaseDomain string `mapstructure:"base_domain"`
|
||||
Nameservers Nameservers
|
||||
SearchDomains []string `mapstructure:"search_domains"`
|
||||
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
|
||||
UserNameInMagicDNS bool `mapstructure:"use_username_in_magic_dns"`
|
||||
MagicDNS bool `mapstructure:"magic_dns"`
|
||||
BaseDomain string `mapstructure:"base_domain"`
|
||||
Nameservers Nameservers
|
||||
SearchDomains []string `mapstructure:"search_domains"`
|
||||
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
|
||||
}
|
||||
|
||||
type Nameservers struct {
|
||||
|
@ -164,7 +162,6 @@ type OIDCConfig struct {
|
|||
AllowedDomains []string
|
||||
AllowedUsers []string
|
||||
AllowedGroups []string
|
||||
StripEmaildomain bool
|
||||
Expiry time.Duration
|
||||
UseExpiryFromToken bool
|
||||
}
|
||||
|
@ -212,6 +209,12 @@ type Tuning struct {
|
|||
NodeMapSessionBufferedChanSize int
|
||||
}
|
||||
|
||||
// LoadConfig prepares and loads the Headscale configuration into Viper.
|
||||
// This means it sets the default values, reads the configuration file and
|
||||
// environment variables, and handles deprecated configuration options.
|
||||
// It has to be called before LoadServerConfig and LoadCLIConfig.
|
||||
// The configuration is not validated and the caller should check for errors
|
||||
// using a validation function.
|
||||
func LoadConfig(path string, isFile bool) error {
|
||||
if isFile {
|
||||
viper.SetConfigFile(path)
|
||||
|
@ -268,7 +271,6 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
||||
|
||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||
viper.SetDefault("oidc.strip_email_domain", true)
|
||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||
viper.SetDefault("oidc.expiry", "180d")
|
||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||
|
@ -284,14 +286,14 @@ func LoadConfig(path string, isFile bool) error {
|
|||
|
||||
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
|
||||
|
||||
if IsCLIConfigured() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateServerConfig() error {
|
||||
depr := deprecator{
|
||||
warns: make(set.Set[string]),
|
||||
fatals: make(set.Set[string]),
|
||||
|
@ -315,8 +317,22 @@ func LoadConfig(path string, isFile bool) error {
|
|||
depr.warn("dns_config.use_username_in_magic_dns")
|
||||
depr.warn("dns.use_username_in_magic_dns")
|
||||
|
||||
depr.fatal("oidc.strip_email_domain")
|
||||
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
||||
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
||||
|
||||
depr.Log()
|
||||
|
||||
for _, removed := range []string{
|
||||
"oidc.strip_email_domain",
|
||||
"dns_config.use_username_in_musername_in_magic_dns",
|
||||
} {
|
||||
if viper.IsSet(removed) {
|
||||
log.Fatal().
|
||||
Msgf("Fatal config error: %s has been removed. Please remove it from your config file", removed)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect any validation errors and return them all at once
|
||||
var errorText string
|
||||
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
|
||||
|
@ -360,12 +376,12 @@ func LoadConfig(path string, isFile bool) error {
|
|||
if errorText != "" {
|
||||
// nolint
|
||||
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetTLSConfig() TLSConfig {
|
||||
func tlsConfig() TLSConfig {
|
||||
return TLSConfig{
|
||||
LetsEncrypt: LetsEncryptConfig{
|
||||
Hostname: viper.GetString("tls_letsencrypt_hostname"),
|
||||
|
@ -384,7 +400,7 @@ func GetTLSConfig() TLSConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetDERPConfig() DERPConfig {
|
||||
func derpConfig() DERPConfig {
|
||||
serverEnabled := viper.GetBool("derp.server.enabled")
|
||||
serverRegionID := viper.GetInt("derp.server.region_id")
|
||||
serverRegionCode := viper.GetString("derp.server.region_code")
|
||||
|
@ -445,7 +461,7 @@ func GetDERPConfig() DERPConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetLogTailConfig() LogTailConfig {
|
||||
func logtailConfig() LogTailConfig {
|
||||
enabled := viper.GetBool("logtail.enabled")
|
||||
|
||||
return LogTailConfig{
|
||||
|
@ -453,7 +469,7 @@ func GetLogTailConfig() LogTailConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetPolicyConfig() PolicyConfig {
|
||||
func policyConfig() PolicyConfig {
|
||||
policyPath := viper.GetString("policy.path")
|
||||
policyMode := viper.GetString("policy.mode")
|
||||
|
||||
|
@ -463,7 +479,7 @@ func GetPolicyConfig() PolicyConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetLogConfig() LogConfig {
|
||||
func logConfig() LogConfig {
|
||||
logLevelStr := viper.GetString("log.level")
|
||||
logLevel, err := zerolog.ParseLevel(logLevelStr)
|
||||
if err != nil {
|
||||
|
@ -473,9 +489,9 @@ func GetLogConfig() LogConfig {
|
|||
logFormatOpt := viper.GetString("log.format")
|
||||
var logFormat string
|
||||
switch logFormatOpt {
|
||||
case "json":
|
||||
case JSONLogFormat:
|
||||
logFormat = JSONLogFormat
|
||||
case "text":
|
||||
case TextLogFormat:
|
||||
logFormat = TextLogFormat
|
||||
case "":
|
||||
logFormat = TextLogFormat
|
||||
|
@ -491,7 +507,7 @@ func GetLogConfig() LogConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func GetDatabaseConfig() DatabaseConfig {
|
||||
func databaseConfig() DatabaseConfig {
|
||||
debug := viper.GetBool("database.debug")
|
||||
|
||||
type_ := viper.GetString("database.type")
|
||||
|
@ -543,7 +559,7 @@ func GetDatabaseConfig() DatabaseConfig {
|
|||
}
|
||||
}
|
||||
|
||||
func DNS() (DNSConfig, error) {
|
||||
func dns() (DNSConfig, error) {
|
||||
var dns DNSConfig
|
||||
|
||||
// TODO: Use this instead of manually getting settings when
|
||||
|
@ -566,21 +582,18 @@ func DNS() (DNSConfig, error) {
|
|||
if err != nil {
|
||||
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
|
||||
}
|
||||
|
||||
dns.ExtraRecords = extraRecords
|
||||
}
|
||||
|
||||
dns.UserNameInMagicDNS = viper.GetBool("dns.use_username_in_magic_dns")
|
||||
|
||||
return dns, nil
|
||||
}
|
||||
|
||||
// GlobalResolvers returns the global DNS resolvers
|
||||
// globalResolvers returns the global DNS resolvers
|
||||
// defined in the config file.
|
||||
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
||||
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
||||
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
||||
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
|
||||
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
||||
for _, nsStr := range d.Nameservers.Global {
|
||||
|
@ -613,11 +626,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
|
|||
return resolvers
|
||||
}
|
||||
|
||||
// SplitResolvers returns a map of domain to DNS resolvers.
|
||||
// splitResolvers returns a map of domain to DNS resolvers.
|
||||
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
||||
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
||||
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
||||
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
|
||||
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
|
||||
routes := make(map[string][]*dnstype.Resolver)
|
||||
for domain, nameservers := range d.Nameservers.Split {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
@ -653,7 +666,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
|
|||
return routes
|
||||
}
|
||||
|
||||
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
cfg := tailcfg.DNSConfig{}
|
||||
|
||||
if dns.BaseDomain == "" && dns.MagicDNS {
|
||||
|
@ -662,9 +675,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
|||
|
||||
cfg.Proxied = dns.MagicDNS
|
||||
cfg.ExtraRecords = dns.ExtraRecords
|
||||
cfg.Resolvers = dns.GlobalResolvers()
|
||||
cfg.Resolvers = dns.globalResolvers()
|
||||
|
||||
routes := dns.SplitResolvers()
|
||||
routes := dns.splitResolvers()
|
||||
cfg.Routes = routes
|
||||
if dns.BaseDomain != "" {
|
||||
cfg.Domains = []string{dns.BaseDomain}
|
||||
|
@ -674,7 +687,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
|||
return &cfg
|
||||
}
|
||||
|
||||
func PrefixV4() (*netip.Prefix, error) {
|
||||
func prefixV4() (*netip.Prefix, error) {
|
||||
prefixV4Str := viper.GetString("prefixes.v4")
|
||||
|
||||
if prefixV4Str == "" {
|
||||
|
@ -698,7 +711,7 @@ func PrefixV4() (*netip.Prefix, error) {
|
|||
return &prefixV4, nil
|
||||
}
|
||||
|
||||
func PrefixV6() (*netip.Prefix, error) {
|
||||
func prefixV6() (*netip.Prefix, error) {
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
if prefixV6Str == "" {
|
||||
|
@ -723,27 +736,41 @@ func PrefixV6() (*netip.Prefix, error) {
|
|||
return &prefixV6, nil
|
||||
}
|
||||
|
||||
func GetHeadscaleConfig() (*Config, error) {
|
||||
if IsCLIConfigured() {
|
||||
return &Config{
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
APIKey: viper.GetString("cli.api_key"),
|
||||
Timeout: viper.GetDuration("cli.timeout"),
|
||||
Insecure: viper.GetBool("cli.insecure"),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
logConfig := GetLogConfig()
|
||||
// LoadCLIConfig returns the needed configuration for the CLI client
|
||||
// of Headscale to connect to a Headscale server.
|
||||
func LoadCLIConfig() (*Config, error) {
|
||||
logConfig := logConfig()
|
||||
zerolog.SetGlobalLevel(logConfig.Level)
|
||||
|
||||
prefix4, err := PrefixV4()
|
||||
return &Config{
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
UnixSocket: viper.GetString("unix_socket"),
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
APIKey: viper.GetString("cli.api_key"),
|
||||
Timeout: viper.GetDuration("cli.timeout"),
|
||||
Insecure: viper.GetBool("cli.insecure"),
|
||||
},
|
||||
Log: logConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadServerConfig returns the full Headscale configuration to
|
||||
// host a Headscale server. This is called as part of `headscale serve`.
|
||||
func LoadServerConfig() (*Config, error) {
|
||||
if err := validateServerConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logConfig := logConfig()
|
||||
zerolog.SetGlobalLevel(logConfig.Level)
|
||||
|
||||
prefix4, err := prefixV4()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix6, err := PrefixV6()
|
||||
prefix6, err := prefixV6()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -760,16 +787,21 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
case string(IPAllocationStrategyRandom):
|
||||
alloc = IPAllocationStrategyRandom
|
||||
default:
|
||||
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
|
||||
return nil, fmt.Errorf(
|
||||
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
|
||||
allocStr,
|
||||
IPAllocationStrategySequential,
|
||||
IPAllocationStrategyRandom,
|
||||
)
|
||||
}
|
||||
|
||||
dnsConfig, err := DNS()
|
||||
dnsConfig, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
derpConfig := GetDERPConfig()
|
||||
logTailConfig := GetLogTailConfig()
|
||||
derpConfig := derpConfig()
|
||||
logTailConfig := logtailConfig()
|
||||
randomizeClientPort := viper.GetBool("randomize_client_port")
|
||||
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
|
@ -794,10 +826,11 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
// - DERP run on their own domains
|
||||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||
//
|
||||
// TODO(kradalby): remove dnsConfig.UserNameInMagicDNS check when removed.
|
||||
if !dnsConfig.UserNameInMagicDNS && dnsConfig.BaseDomain != "" && strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
||||
return nil, errors.New("server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.")
|
||||
if dnsConfig.BaseDomain != "" &&
|
||||
strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
||||
return nil, errors.New(
|
||||
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
||||
)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
|
@ -806,7 +839,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
MetricsAddr: viper.GetString("metrics_listen_addr"),
|
||||
GRPCAddr: viper.GetString("grpc_listen_addr"),
|
||||
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
||||
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
||||
DisableUpdateCheck: false,
|
||||
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
|
@ -823,12 +856,11 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
"ephemeral_node_inactivity_timeout",
|
||||
),
|
||||
|
||||
Database: GetDatabaseConfig(),
|
||||
Database: databaseConfig(),
|
||||
|
||||
TLS: GetTLSConfig(),
|
||||
TLS: tlsConfig(),
|
||||
|
||||
DNSConfig: DNSToTailcfgDNS(dnsConfig),
|
||||
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
|
||||
DNSConfig: dnsToTailcfgDNS(dnsConfig),
|
||||
|
||||
ACMEEmail: viper.GetString("acme_email"),
|
||||
ACMEURL: viper.GetString("acme_url"),
|
||||
|
@ -840,15 +872,14 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
||||
"oidc.only_start_if_oidc_is_available",
|
||||
),
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: oidcClientSecret,
|
||||
Scope: viper.GetStringSlice("oidc.scope"),
|
||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: oidcClientSecret,
|
||||
Scope: viper.GetStringSlice("oidc.scope"),
|
||||
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
||||
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
||||
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
||||
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
||||
Expiry: func() time.Duration {
|
||||
// if set to 0, we assume no expiry
|
||||
if value := viper.GetString("oidc.expiry"); value == "0" {
|
||||
|
@ -870,7 +901,7 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
LogTail: logTailConfig,
|
||||
RandomizeClientPort: randomizeClientPort,
|
||||
|
||||
Policy: GetPolicyConfig(),
|
||||
Policy: policyConfig(),
|
||||
|
||||
CLI: CLIConfig{
|
||||
Address: viper.GetString("cli.address"),
|
||||
|
@ -883,17 +914,15 @@ func GetHeadscaleConfig() (*Config, error) {
|
|||
|
||||
// TODO(kradalby): Document these settings when more stable
|
||||
Tuning: Tuning{
|
||||
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
|
||||
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
||||
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
|
||||
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
|
||||
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
||||
NodeMapSessionBufferedChanSize: viper.GetInt(
|
||||
"tuning.node_mapsession_buffered_chan_size",
|
||||
),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func IsCLIConfigured() bool {
|
||||
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
|
||||
}
|
||||
|
||||
type deprecator struct {
|
||||
warns set.Set[string]
|
||||
fatals set.Set[string]
|
||||
|
@ -905,14 +934,26 @@ func (d *deprecator) warnWithAlias(newKey, oldKey string) {
|
|||
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
|
||||
viper.RegisterAlias(newKey, oldKey)
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", oldKey, newKey, oldKey))
|
||||
d.warns.Add(
|
||||
fmt.Sprintf(
|
||||
"The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.",
|
||||
oldKey,
|
||||
newKey,
|
||||
oldKey,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
|
||||
func (d *deprecator) fatal(newKey, oldKey string) {
|
||||
func (d *deprecator) fatal(oldKey string) {
|
||||
if viper.IsSet(oldKey) {
|
||||
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
||||
d.fatals.Add(
|
||||
fmt.Sprintf(
|
||||
"The %q configuration key has been removed. Please see the changelog for more details.",
|
||||
oldKey,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -920,7 +961,14 @@ func (d *deprecator) fatal(newKey, oldKey string) {
|
|||
// If the new key is set, a warning is emitted instead.
|
||||
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
||||
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
|
||||
d.fatals.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
||||
d.fatals.Add(
|
||||
fmt.Sprintf(
|
||||
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
||||
oldKey,
|
||||
newKey,
|
||||
oldKey,
|
||||
),
|
||||
)
|
||||
} else if viper.IsSet(oldKey) {
|
||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
||||
}
|
||||
|
@ -929,14 +977,26 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
|||
// warn deprecates and adds an option to log a warning if the oldKey is set.
|
||||
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
||||
d.warns.Add(
|
||||
fmt.Sprintf(
|
||||
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
||||
oldKey,
|
||||
newKey,
|
||||
oldKey,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
|
||||
func (d *deprecator) warn(oldKey string) {
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", oldKey))
|
||||
d.warns.Add(
|
||||
fmt.Sprintf(
|
||||
"The %q configuration key is deprecated and has been removed. Please see the changelog for more details.",
|
||||
oldKey,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) {
|
|||
name: "unmarshal-dns-full-config",
|
||||
configPath: "testdata/dns_full.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
dns, err := DNS()
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -40,20 +42,19 @@ func TestReadConfig(t *testing.T) {
|
|||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||
},
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
UserNameInMagicDNS: true,
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dns-to-tailcfg.DNSConfig",
|
||||
configPath: "testdata/dns_full.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
dns, err := DNS()
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return DNSToTailcfgDNS(dns), nil
|
||||
return dnsToTailcfgDNS(dns), nil
|
||||
},
|
||||
want: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
|
@ -79,7 +80,7 @@ func TestReadConfig(t *testing.T) {
|
|||
name: "unmarshal-dns-full-no-magic",
|
||||
configPath: "testdata/dns_full_no_magic.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
dns, err := DNS()
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -97,20 +98,19 @@ func TestReadConfig(t *testing.T) {
|
|||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||
{Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||
},
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
UserNameInMagicDNS: true,
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dns-to-tailcfg.DNSConfig",
|
||||
configPath: "testdata/dns_full_no_magic.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
dns, err := DNS()
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return DNSToTailcfgDNS(dns), nil
|
||||
return dnsToTailcfgDNS(dns), nil
|
||||
},
|
||||
want: &tailcfg.DNSConfig{
|
||||
Proxied: false,
|
||||
|
@ -136,7 +136,7 @@ func TestReadConfig(t *testing.T) {
|
|||
name: "base-domain-in-server-url-err",
|
||||
configPath: "testdata/base-domain-in-server-url.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
return GetHeadscaleConfig()
|
||||
return LoadServerConfig()
|
||||
},
|
||||
want: nil,
|
||||
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
||||
|
@ -145,7 +145,7 @@ func TestReadConfig(t *testing.T) {
|
|||
name: "base-domain-not-in-server-url",
|
||||
configPath: "testdata/base-domain-not-in-server-url.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
cfg, err := GetHeadscaleConfig()
|
||||
cfg, err := LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func TestReadConfig(t *testing.T) {
|
|||
name: "policy-path-is-loaded",
|
||||
configPath: "testdata/policy-path-is-loaded.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
cfg, err := GetHeadscaleConfig()
|
||||
cfg, err := LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -232,11 +232,10 @@ func TestReadConfigFromEnv(t *testing.T) {
|
|||
{
|
||||
name: "unmarshal-dns-full-config",
|
||||
configEnv: map[string]string{
|
||||
"HEADSCALE_DNS_MAGIC_DNS": "true",
|
||||
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
|
||||
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
|
||||
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
|
||||
"HEADSCALE_DNS_USE_USERNAME_IN_MAGIC_DNS": "true",
|
||||
"HEADSCALE_DNS_MAGIC_DNS": "true",
|
||||
"HEADSCALE_DNS_BASE_DOMAIN": "example.com",
|
||||
"HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`,
|
||||
"HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com",
|
||||
|
||||
// TODO(kradalby): Figure out how to pass these as env vars
|
||||
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
|
||||
|
@ -245,7 +244,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
|||
setup: func(t *testing.T) (any, error) {
|
||||
t.Logf("all settings: %#v", viper.AllSettings())
|
||||
|
||||
dns, err := DNS()
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -264,8 +263,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
|||
ExtraRecords: []tailcfg.DNSRecord{
|
||||
// {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"},
|
||||
},
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
UserNameInMagicDNS: true,
|
||||
SearchDomains: []string{"test.com", "bar.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -289,3 +287,49 @@ func TestReadConfigFromEnv(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfigValidation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// defer os.RemoveAll(tmpDir)
|
||||
configYaml := []byte(`---
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: ""
|
||||
tls_cert_path: abc.pem
|
||||
noise:
|
||||
private_key_path: noise_private.key`)
|
||||
|
||||
// Populate a custom config file
|
||||
configFilePath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
|
||||
// Check configuration validation errors (1)
|
||||
err = LoadConfig(tmpDir, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = validateServerConfig()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
|
||||
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
|
||||
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
|
||||
|
||||
// Check configuration validation errors (2)
|
||||
configYaml = []byte(`---
|
||||
noise:
|
||||
private_key_path: noise_private.key
|
||||
server_url: http://127.0.0.1:8080
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||
`)
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
err = LoadConfig(tmpDir, false)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -393,7 +393,7 @@ func (node *Node) Proto() *v1.Node {
|
|||
return nodeProto
|
||||
}
|
||||
|
||||
func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
|
||||
func (node *Node) GetFQDN(baseDomain string) (string, error) {
|
||||
if node.GivenName == "" {
|
||||
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
|
||||
}
|
||||
|
@ -408,19 +408,6 @@ func (node *Node) GetFQDN(cfg *Config, baseDomain string) (string, error) {
|
|||
)
|
||||
}
|
||||
|
||||
if cfg.DNSUserNameInMagicDNS {
|
||||
if node.User.Name == "" {
|
||||
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
|
||||
}
|
||||
|
||||
hostname = fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
node.GivenName,
|
||||
node.User.Name,
|
||||
baseDomain,
|
||||
)
|
||||
}
|
||||
|
||||
if len(hostname) > MaxHostnameLength {
|
||||
return "", fmt.Errorf(
|
||||
"failed to create valid FQDN (%s): %w",
|
||||
|
|
|
@ -127,76 +127,10 @@ func TestNodeFQDN(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
node Node
|
||||
cfg Config
|
||||
domain string
|
||||
want string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "all-set-with-username",
|
||||
node: Node{
|
||||
GivenName: "test",
|
||||
User: User{
|
||||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: true,
|
||||
},
|
||||
domain: "example.com",
|
||||
want: "test.user.example.com",
|
||||
},
|
||||
{
|
||||
name: "no-given-name-with-username",
|
||||
node: Node{
|
||||
User: User{
|
||||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: true,
|
||||
},
|
||||
domain: "example.com",
|
||||
wantErr: "failed to create valid FQDN: node has no given name",
|
||||
},
|
||||
{
|
||||
name: "no-user-name-with-username",
|
||||
node: Node{
|
||||
GivenName: "test",
|
||||
User: User{},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: true,
|
||||
},
|
||||
domain: "example.com",
|
||||
wantErr: "failed to create valid FQDN: node user has no name",
|
||||
},
|
||||
{
|
||||
name: "no-magic-dns-with-username",
|
||||
node: Node{
|
||||
GivenName: "test",
|
||||
User: User{
|
||||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: false,
|
||||
},
|
||||
DNSUserNameInMagicDNS: true,
|
||||
},
|
||||
domain: "example.com",
|
||||
want: "test.user.example.com",
|
||||
},
|
||||
{
|
||||
name: "no-dnsconfig-with-username",
|
||||
node: Node{
|
||||
|
@ -216,12 +150,6 @@ func TestNodeFQDN(t *testing.T) {
|
|||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: false,
|
||||
},
|
||||
domain: "example.com",
|
||||
want: "test.example.com",
|
||||
},
|
||||
|
@ -232,46 +160,16 @@ func TestNodeFQDN(t *testing.T) {
|
|||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: false,
|
||||
},
|
||||
domain: "example.com",
|
||||
wantErr: "failed to create valid FQDN: node has no given name",
|
||||
},
|
||||
{
|
||||
name: "no-user-name",
|
||||
name: "too-long-username",
|
||||
node: Node{
|
||||
GivenName: "test",
|
||||
User: User{},
|
||||
GivenName: "useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111",
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: true,
|
||||
},
|
||||
DNSUserNameInMagicDNS: false,
|
||||
},
|
||||
domain: "example.com",
|
||||
want: "test.example.com",
|
||||
},
|
||||
{
|
||||
name: "no-magic-dns",
|
||||
node: Node{
|
||||
GivenName: "test",
|
||||
User: User{
|
||||
Name: "user",
|
||||
},
|
||||
},
|
||||
cfg: Config{
|
||||
DNSConfig: &tailcfg.DNSConfig{
|
||||
Proxied: false,
|
||||
},
|
||||
DNSUserNameInMagicDNS: false,
|
||||
},
|
||||
domain: "example.com",
|
||||
want: "test.example.com",
|
||||
domain: "example.com",
|
||||
wantErr: "failed to create valid FQDN (useruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruseruser11111111111111111111111111111111111111113444444444444444444444444444444444444444444444444444444444444444444444441111111111111111111111111111111111111111111111111111111111111111111111.example.com): hostname too long, cannot except 255 ASCII chars",
|
||||
},
|
||||
{
|
||||
name: "no-dnsconfig",
|
||||
|
@ -288,7 +186,9 @@ func TestNodeFQDN(t *testing.T) {
|
|||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := tc.node.GetFQDN(&tc.cfg, tc.domain)
|
||||
got, err := tc.node.GetFQDN(tc.domain)
|
||||
|
||||
t.Logf("GOT: %q, %q", got, tc.domain)
|
||||
|
||||
if (err != nil) && (err.Error() != tc.wantErr) {
|
||||
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strconv"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
|
@ -16,19 +17,57 @@ import (
|
|||
// that contain our machines.
|
||||
type User struct {
|
||||
gorm.Model
|
||||
|
||||
// Username for the user, is used if email is empty
|
||||
// Should not be used, please use Username().
|
||||
Name string `gorm:"unique"`
|
||||
|
||||
// Typically a full name of the user
|
||||
DisplayName string
|
||||
|
||||
// Email of the user
|
||||
// Should not be used, please use Username().
|
||||
Email string
|
||||
|
||||
// Unique identifier of the user from OIDC,
|
||||
// comes from `sub` claim in the OIDC token
|
||||
// and is used to lookup the user.
|
||||
ProviderIdentifier string `gorm:"index"`
|
||||
|
||||
// Provider is the origin of the user account,
|
||||
// same as RegistrationMethod, without authkey.
|
||||
Provider string
|
||||
|
||||
ProfilePicURL string
|
||||
}
|
||||
|
||||
// Username is the main way to get the username of a user,
|
||||
// it will return the email if it exists, the name if it exists,
|
||||
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
|
||||
// Email and OIDCIdentifier will be set when the user has headscale
|
||||
// enabled with OIDC, which means that there is a domain involved which
|
||||
// should be used throughout headscale, in information returned to the
|
||||
// user and the Policy engine.
|
||||
func (u *User) Username() string {
|
||||
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
||||
}
|
||||
|
||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||
// it will return the Username.
|
||||
func (u *User) DisplayNameOrUsername() string {
|
||||
return cmp.Or(u.DisplayName, u.Username())
|
||||
}
|
||||
|
||||
// TODO(kradalby): See if we can fill in Gravatar here
|
||||
func (u *User) profilePicURL() string {
|
||||
return ""
|
||||
return u.ProfilePicURL
|
||||
}
|
||||
|
||||
func (u *User) TailscaleUser() *tailcfg.User {
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
LoginName: u.Name,
|
||||
DisplayName: u.Name,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
Logins: []tailcfg.LoginID{},
|
||||
Created: u.CreatedAt,
|
||||
|
@ -41,9 +80,9 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
|
|||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
// TODO(kradalby): this should reflect registration method.
|
||||
Provider: "",
|
||||
LoginName: u.Name,
|
||||
DisplayName: u.Name,
|
||||
Provider: u.Provider,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
|
||||
|
@ -53,8 +92,8 @@ func (u *User) TailscaleLogin() *tailcfg.Login {
|
|||
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
||||
return tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
LoginName: u.Name,
|
||||
DisplayName: u.Name,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
}
|
||||
|
@ -66,3 +105,27 @@ func (n *User) Proto() *v1.User {
|
|||
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCClaims struct {
|
||||
// Sub is the user's unique identifier at the provider.
|
||||
Sub string `json:"sub"`
|
||||
|
||||
// Name is the user's full name.
|
||||
Name string `json:"name,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"email_verified,omitempty"`
|
||||
ProfilePictureURL string `json:"picture,omitempty"`
|
||||
Username string `json:"preferred_username,omitempty"`
|
||||
}
|
||||
|
||||
// FromClaim overrides a User from OIDC claims.
|
||||
// All fields will be updated, except for the ID.
|
||||
func (u *User) FromClaim(claims *OIDCClaims) {
|
||||
u.ProviderIdentifier = claims.Sub
|
||||
u.DisplayName = claims.Username
|
||||
u.Email = claims.Email
|
||||
u.Name = claims.Username
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
u.Provider = util.RegisterMethodOIDC
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
@ -25,38 +24,6 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
|||
|
||||
var ErrInvalidUserName = errors.New("invalid user name")
|
||||
|
||||
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
|
||||
strip := viper.GetBool("oidc.strip_email_domain")
|
||||
|
||||
return NormalizeToFQDNRules(name, strip)
|
||||
}
|
||||
|
||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, "'", "")
|
||||
atIdx := strings.Index(name, "@")
|
||||
if stripEmailDomain && atIdx > 0 {
|
||||
name = name[:atIdx]
|
||||
} else {
|
||||
name = strings.ReplaceAll(name, "@", ".")
|
||||
}
|
||||
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||
|
||||
for _, elt := range strings.Split(name, ".") {
|
||||
if len(elt) > LabelHostnameLength {
|
||||
return "", fmt.Errorf(
|
||||
"label %v is more than 63 chars: %w",
|
||||
elt,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func CheckForFQDNRules(name string) error {
|
||||
if len(name) > LabelHostnameLength {
|
||||
return fmt.Errorf(
|
||||
|
|
|
@ -7,100 +7,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNormalizeToFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
stripEmailDomain bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normalize simple name",
|
||||
args: args{
|
||||
name: "normalize-simple.name",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "normalize-simple.name",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email domain should be removed",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "foo.bar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "strip enabled no email passed as argument",
|
||||
args: args{
|
||||
name: "not-email-and-strip-enabled",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "not-email-and-strip-enabled",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize complex email",
|
||||
args: args{
|
||||
name: "foo.bar+complex-email@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar-complex-email.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user name with space",
|
||||
args: args{
|
||||
name: "name space",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "name-space",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user with quote",
|
||||
args: args{
|
||||
name: "Jamie's iPhone 5",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "jamies-iphone-5",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
||||
err,
|
||||
tt.wantErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
|
|
|
@ -276,7 +276,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
|||
hsic.WithACLPolicy(&testCase.policy),
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErr(t, err)
|
||||
|
@ -316,7 +316,7 @@ func TestACLAllowUser80Dst(t *testing.T) {
|
|||
},
|
||||
1,
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErr(t, err)
|
||||
|
@ -373,7 +373,7 @@ func TestACLDenyAllPort80(t *testing.T) {
|
|||
},
|
||||
4,
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErr(t, err)
|
||||
|
@ -417,7 +417,7 @@ func TestACLAllowUserDst(t *testing.T) {
|
|||
},
|
||||
2,
|
||||
)
|
||||
// defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErr(t, err)
|
||||
|
@ -473,7 +473,7 @@ func TestACLAllowStarDst(t *testing.T) {
|
|||
},
|
||||
2,
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErr(t, err)
|
||||
|
@ -534,7 +534,7 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
|||
},
|
||||
3,
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErr(t, err)
|
||||
|
@ -672,7 +672,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
|||
&testCase.policy,
|
||||
2,
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// Since user/users dont matter here, we basically expect that some clients
|
||||
// will be assigned these ips and that we can pick them up for our own use.
|
||||
|
@ -1021,7 +1021,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
|
|
|
@ -48,7 +48,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
|||
scenario := AuthOIDCScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -62,7 +62,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
|||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(
|
||||
|
@ -108,7 +107,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
scenario := AuthOIDCScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 3,
|
||||
|
@ -121,7 +120,6 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret,
|
||||
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain),
|
||||
"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
|
||||
}
|
||||
|
||||
|
@ -276,7 +274,6 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
|||
),
|
||||
ClientID: "superclient",
|
||||
ClientSecret: "supersecret",
|
||||
StripEmaildomain: true,
|
||||
OnlyStartIfOIDCIsAvailable: true,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
|||
scenario := AuthWebFlowScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -73,7 +73,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
|||
scenario := AuthWebFlowScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -34,7 +35,7 @@ func TestUserCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 0,
|
||||
|
@ -114,7 +115,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 0,
|
||||
|
@ -256,7 +257,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 0,
|
||||
|
@ -319,7 +320,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 0,
|
||||
|
@ -397,7 +398,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user1: 1,
|
||||
|
@ -491,7 +492,7 @@ func TestApiKeyCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 0,
|
||||
|
@ -659,7 +660,7 @@ func TestNodeTagCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 0,
|
||||
|
@ -735,13 +736,7 @@ func TestNodeTagCommand(t *testing.T) {
|
|||
|
||||
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
|
||||
|
||||
// try to set a wrong tag and retrieve the error
|
||||
type errOutput struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
var errorOutput errOutput
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
|
@ -750,10 +745,8 @@ func TestNodeTagCommand(t *testing.T) {
|
|||
"-t", "wrong-tag",
|
||||
"--output", "json",
|
||||
},
|
||||
&errorOutput,
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'")
|
||||
assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
|
||||
|
||||
// Test list all nodes after added seconds
|
||||
resultMachines := make([]*v1.Node, len(machineKeys))
|
||||
|
@ -792,7 +785,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
|
@ -842,7 +835,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
|
@ -905,7 +898,7 @@ func TestNodeCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"node-user": 0,
|
||||
|
@ -1146,7 +1139,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"node-expire-user": 0,
|
||||
|
@ -1273,7 +1266,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"node-rename-command": 0,
|
||||
|
@ -1398,18 +1391,17 @@ func TestNodeRenameCommand(t *testing.T) {
|
|||
assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5")
|
||||
|
||||
// Test failure for too long names
|
||||
result, err := headscale.Execute(
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"rename",
|
||||
"--identifier",
|
||||
fmt.Sprintf("%d", listAll[4].GetId()),
|
||||
"testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890",
|
||||
strings.Repeat("t", 64),
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, result, "not be over 63 chars")
|
||||
assert.ErrorContains(t, err, "not be over 63 chars")
|
||||
|
||||
var listAllAfterRenameAttempt []v1.Node
|
||||
err = executeAndUnmarshal(
|
||||
|
@ -1440,7 +1432,7 @@ func TestNodeMoveCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"old-user": 0,
|
||||
|
@ -1536,7 +1528,7 @@ func TestNodeMoveCommand(t *testing.T) {
|
|||
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
|
||||
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
|
||||
|
||||
moveToNonExistingNSResult, err := headscale.Execute(
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
|
@ -1549,11 +1541,9 @@ func TestNodeMoveCommand(t *testing.T) {
|
|||
"json",
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Contains(
|
||||
assert.ErrorContains(
|
||||
t,
|
||||
moveToNonExistingNSResult,
|
||||
err,
|
||||
"user not found",
|
||||
)
|
||||
assert.Equal(t, node.GetUser().GetName(), "new-user")
|
||||
|
@ -1603,7 +1593,7 @@ func TestPolicyCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"policy-user": 0,
|
||||
|
@ -1683,7 +1673,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"policy-user": 1,
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
)
|
||||
|
||||
type ControlServer interface {
|
||||
Shutdown() error
|
||||
SaveLog(string) error
|
||||
Shutdown() (string, string, error)
|
||||
SaveLog(string) (string, string, error)
|
||||
SaveProfile(string) error
|
||||
Execute(command []string) (string, error)
|
||||
WriteFile(path string, content []byte) error
|
||||
|
|
|
@ -17,7 +17,7 @@ func TestResolveMagicDNS(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"magicdns1": len(MustTestVersions),
|
||||
|
@ -208,7 +208,7 @@ func TestValidateResolvConf(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"resolvconf1": 3,
|
||||
|
|
|
@ -17,10 +17,10 @@ func SaveLog(
|
|||
pool *dockertest.Pool,
|
||||
resource *dockertest.Resource,
|
||||
basePath string,
|
||||
) error {
|
||||
) (string, string, error) {
|
||||
err := os.MkdirAll(basePath, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var stdout bytes.Buffer
|
||||
|
@ -41,28 +41,30 @@ func SaveLog(
|
|||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
|
||||
|
||||
stdoutPath := path.Join(basePath, resource.Container.Name+".stdout.log")
|
||||
err = os.WriteFile(
|
||||
path.Join(basePath, resource.Container.Name+".stdout.log"),
|
||||
stdoutPath,
|
||||
stdout.Bytes(),
|
||||
filePerm,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
stderrPath := path.Join(basePath, resource.Container.Name+".stderr.log")
|
||||
err = os.WriteFile(
|
||||
path.Join(basePath, resource.Container.Name+".stderr.log"),
|
||||
stderrPath,
|
||||
stderr.Bytes(),
|
||||
filePerm,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return nil
|
||||
return stdoutPath, stderrPath, nil
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestDERPServerScenario(t *testing.T) {
|
|||
Scenario: baseScenario,
|
||||
tsicNetworks: map[string]*dockertest.Network{},
|
||||
}
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
|
|
@ -27,7 +27,7 @@ func TestPingAllByIP(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// TODO(kradalby): it does not look like the user thing works, only second
|
||||
// get created? maybe only when many?
|
||||
|
@ -71,7 +71,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -109,7 +109,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -228,7 +228,7 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -313,7 +313,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -427,7 +427,7 @@ func TestPingAllByHostname(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user3": len(MustTestVersions),
|
||||
|
@ -476,7 +476,7 @@ func TestTaildrop(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"taildrop": len(MustTestVersions),
|
||||
|
@ -637,7 +637,7 @@ func TestExpireNode(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -763,7 +763,7 @@ func TestNodeOnlineStatus(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
"user1": len(MustTestVersions),
|
||||
|
@ -878,7 +878,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// TODO(kradalby): it does not look like the user thing works, only second
|
||||
// get created? maybe only when many?
|
||||
|
@ -954,3 +954,102 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
|||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
}
|
||||
}
|
||||
|
||||
func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
// TODO(kradalby): it does not look like the user thing works, only second
|
||||
// get created? maybe only when many?
|
||||
spec := map[string]int{
|
||||
"user1": 1,
|
||||
"user2": 1,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{},
|
||||
hsic.WithTestName("deletenocrash"),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithHostnameAsServerURL(),
|
||||
)
|
||||
assertNoErrHeadscaleEnv(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
assertNoErrListClientIPs(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||
return x.String()
|
||||
})
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
|
||||
// Test list all nodes after added otherUser
|
||||
var nodeList []v1.Node
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodeList,
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, nodeList, 2)
|
||||
assert.True(t, nodeList[0].Online)
|
||||
assert.True(t, nodeList[1].Online)
|
||||
|
||||
// Delete the first node, which is online
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"delete",
|
||||
"--identifier",
|
||||
// Delete the last added machine
|
||||
fmt.Sprintf("%d", nodeList[0].Id),
|
||||
"--output",
|
||||
"json",
|
||||
"--force",
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Ensure that the node has been deleted, this did not occur due to a panic.
|
||||
var nodeListAfter []v1.Node
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodeListAfter,
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, nodeListAfter, 1)
|
||||
assert.True(t, nodeListAfter[0].Online)
|
||||
assert.Equal(t, nodeList[1].Id, nodeListAfter[0].Id)
|
||||
|
||||
}
|
||||
|
|
|
@ -398,8 +398,8 @@ func (t *HeadscaleInContainer) hasTLS() bool {
|
|||
}
|
||||
|
||||
// Shutdown stops and cleans up the Headscale container.
|
||||
func (t *HeadscaleInContainer) Shutdown() error {
|
||||
err := t.SaveLog("/tmp/control")
|
||||
func (t *HeadscaleInContainer) Shutdown() (string, string, error) {
|
||||
stdoutPath, stderrPath, err := t.SaveLog("/tmp/control")
|
||||
if err != nil {
|
||||
log.Printf(
|
||||
"Failed to save log from control: %s",
|
||||
|
@ -458,12 +458,12 @@ func (t *HeadscaleInContainer) Shutdown() error {
|
|||
t.pool.Purge(t.pgContainer)
|
||||
}
|
||||
|
||||
return t.pool.Purge(t.container)
|
||||
return stdoutPath, stderrPath, t.pool.Purge(t.container)
|
||||
}
|
||||
|
||||
// SaveLog saves the current stdout log of the container to a path
|
||||
// on the host system.
|
||||
func (t *HeadscaleInContainer) SaveLog(path string) error {
|
||||
func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) {
|
||||
return dockertestutil.SaveLog(t.pool, t.container, path)
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 3,
|
||||
|
@ -254,7 +254,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 3,
|
||||
|
@ -826,7 +826,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 1,
|
||||
|
@ -968,7 +968,7 @@ func TestAutoApprovedSubRoute2068(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 1,
|
||||
|
@ -1059,7 +1059,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
spec := map[string]int{
|
||||
user: 2,
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
"github.com/ory/dockertest/v3"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/envknob"
|
||||
)
|
||||
|
@ -187,13 +189,9 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
|
||||
// and networks associated with it.
|
||||
// In addition, it will save the logs of the ControlServer to `/tmp/control` in the
|
||||
// environment running the tests.
|
||||
func (s *Scenario) Shutdown() {
|
||||
func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
||||
s.controlServers.Range(func(_ string, control ControlServer) bool {
|
||||
err := control.Shutdown()
|
||||
stdoutPath, stderrPath, err := control.Shutdown()
|
||||
if err != nil {
|
||||
log.Printf(
|
||||
"Failed to shut down control: %s",
|
||||
|
@ -201,6 +199,16 @@ func (s *Scenario) Shutdown() {
|
|||
)
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
stdout, err := os.ReadFile(stdoutPath)
|
||||
assert.NoError(t, err)
|
||||
assert.NotContains(t, string(stdout), "panic")
|
||||
|
||||
stderr, err := os.ReadFile(stderrPath)
|
||||
assert.NoError(t, err)
|
||||
assert.NotContains(t, string(stderr), "panic")
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
@ -224,6 +232,14 @@ func (s *Scenario) Shutdown() {
|
|||
// }
|
||||
}
|
||||
|
||||
// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient)
|
||||
// and networks associated with it.
|
||||
// In addition, it will save the logs of the ControlServer to `/tmp/control` in the
|
||||
// environment running the tests.
|
||||
func (s *Scenario) Shutdown() {
|
||||
s.ShutdownAssertNoPanics(nil)
|
||||
}
|
||||
|
||||
// Users returns the name of all users associated with the Scenario.
|
||||
func (s *Scenario) Users() []string {
|
||||
users := make([]string, 0)
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
// This file is intended to "test the test framework", by proxy it will also test
|
||||
// some Headcsale/Tailscale stuff, but mostly in very simple ways.
|
||||
// some Headscale/Tailscale stuff, but mostly in very simple ways.
|
||||
|
||||
func IntegrationSkip(t *testing.T) {
|
||||
t.Helper()
|
||||
|
@ -35,7 +35,7 @@ func TestHeadscale(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
headscale, err := scenario.Headscale()
|
||||
|
@ -80,7 +80,7 @@ func TestCreateTailscale(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
scenario.users[user] = &User{
|
||||
Clients: make(map[string]TailscaleClient),
|
||||
|
@ -116,7 +116,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
|||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
headscale, err := scenario.Headscale()
|
||||
|
|
|
@ -111,7 +111,7 @@ func TestSSHOneUserToAll(t *testing.T) {
|
|||
},
|
||||
len(MustTestVersions),
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
@ -176,7 +176,7 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
|
|||
},
|
||||
len(MustTestVersions),
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
nsOneClients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErrListClients(t, err)
|
||||
|
@ -222,7 +222,7 @@ func TestSSHNoSSHConfigured(t *testing.T) {
|
|||
},
|
||||
len(MustTestVersions),
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
@ -271,7 +271,7 @@ func TestSSHIsBlockedInACL(t *testing.T) {
|
|||
},
|
||||
len(MustTestVersions),
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
@ -327,7 +327,7 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
|
|||
},
|
||||
len(MustTestVersions),
|
||||
)
|
||||
defer scenario.Shutdown()
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
ssh1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
assertNoErrListClients(t, err)
|
||||
|
|
|
@ -998,7 +998,9 @@ func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
|
|||
// SaveLog saves the current stdout log of the container to a path
|
||||
// on the host system.
|
||||
func (t *TailscaleInContainer) SaveLog(path string) error {
|
||||
return dockertestutil.SaveLog(t.pool, t.container, path)
|
||||
// TODO(kradalby): Assert if tailscale logs contains panics.
|
||||
_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadFile reads a file from the Tailscale container.
|
||||
|
|
11
mkdocs.yml
11
mkdocs.yml
|
@ -10,7 +10,7 @@ repo_name: juanfont/headscale
|
|||
repo_url: https://github.com/juanfont/headscale
|
||||
|
||||
# Copyright
|
||||
copyright: Copyright © 2023 Headscale authors
|
||||
copyright: Copyright © 2024 Headscale authors
|
||||
|
||||
# Configuration
|
||||
theme:
|
||||
|
@ -55,6 +55,13 @@ theme:
|
|||
favicon: assets/favicon.png
|
||||
logo: ./logo/headscale3-dots.svg
|
||||
|
||||
# Excludes
|
||||
exclude_docs: |
|
||||
/packaging/README.md
|
||||
/packaging/postinstall.sh
|
||||
/packaging/postremove.sh
|
||||
/requirements.txt
|
||||
|
||||
# Plugins
|
||||
plugins:
|
||||
- search:
|
||||
|
@ -139,5 +146,5 @@ nav:
|
|||
- Remote CLI: remote-cli.md
|
||||
- Usage:
|
||||
- Android: android-client.md
|
||||
- Apple: apple-client.md
|
||||
- Windows: windows-client.md
|
||||
- iOS: iOS-client.md
|
||||
|
|
Loading…
Reference in a new issue