Only load needed part of configuration (#2109)

This commit is contained in:
Kristoffer Dalby 2024-09-07 09:23:58 +02:00 committed by GitHub
parent f368ed01ed
commit 8a3a0fee3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 196 additions and 324 deletions

1
.gitignore vendored
View file

@ -22,6 +22,7 @@ dist/
/headscale
config.json
config.yaml
config*.yaml
derp.yaml
*.hujson
*.key

View file

@ -72,6 +72,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792)
- Fix for registering nodes using preauthkeys when running on a postgres database in a non-UTC timezone. [#764](https://github.com/juanfont/headscale/issues/764)
- Make sure integration tests cover postgres for all scenarios
- CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109)
- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109)
## 0.22.3 (2023-05-12)

View file

@ -54,7 +54,7 @@ var listAPIKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -67,14 +67,10 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetApiKeys(), "", output)
return
}
tableData := pterm.TableData{
@ -102,8 +98,6 @@ var listAPIKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -119,9 +113,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
log.Trace().
Msg("Preparing to create ApiKey")
request := &v1.CreateApiKeyRequest{}
durationStr, _ := cmd.Flags().GetString("expiration")
@ -133,19 +124,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -156,8 +141,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Cannot create Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
@ -178,11 +161,9 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -197,8 +178,6 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)
@ -219,11 +198,9 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -238,8 +215,6 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key deleted", output)

View file

@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{
Short: "Test the configuration.",
Long: "Run a test of the configuration and exit.",
Run: func(cmd *cobra.Command, args []string) {
_, err := getHeadscaleApp()
_, err := newHeadscaleServerWithConfig()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing")
}

View file

@ -64,11 +64,9 @@ var createNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -79,8 +77,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return
}
machineKey, err := cmd.Flags().GetString("key")
@ -90,8 +86,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return
}
var mkey key.MachinePublic
@ -102,8 +96,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
@ -113,8 +105,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
request := &v1.DebugCreateNodeRequest{
@ -131,8 +121,6 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
output,
)
return
}
SuccessOutput(response.GetNode(), "Node created", output)

View file

@ -116,11 +116,9 @@ var registerNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -131,8 +129,6 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node key from flag: %s", err),
output,
)
return
}
request := &v1.RegisterNodeRequest{
@ -150,8 +146,6 @@ var registerNodeCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(
@ -169,17 +163,13 @@ var listNodesCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
showTags, err := cmd.Flags().GetBool("tags")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -194,21 +184,15 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
return
}
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -218,8 +202,6 @@ var listNodesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -243,7 +225,7 @@ var expireNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -286,7 +268,7 @@ var renameNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -335,7 +317,7 @@ var deleteNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -435,7 +417,7 @@ var moveNodeCmd = &cobra.Command{
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -508,7 +490,7 @@ be assigned to nodes.`,
return
}
if confirm {
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -681,7 +663,7 @@ var tagCmd = &cobra.Command{
Aliases: []string{"tags", "t"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()

View file

@ -1,6 +1,7 @@
package cli
import (
"fmt"
"io"
"os"
@ -30,7 +31,8 @@ var getPolicy = &cobra.Command{
Short: "Print the current ACL Policy",
Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) {
ctx, client, conn, cancel := getHeadscaleCLIClient()
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -38,13 +40,13 @@ var getPolicy = &cobra.Command{
response, err := client.GetPolicy(ctx, request)
if err != nil {
log.Fatal().Err(err).Msg("Failed to get the policy")
return
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
}
// TODO(pallabpain): Maybe print this better?
SuccessOutput("", response.GetPolicy(), "hujson")
// This does not pass output as we dont support yaml, json or json-line
// output for this command. It is HuJSON already.
SuccessOutput("", response.GetPolicy(), "")
},
}
@ -56,33 +58,28 @@ var setPolicy = &cobra.Command{
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
Aliases: []string{"put", "update"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath)
if err != nil {
log.Fatal().Err(err).Msg("Error opening the policy file")
return
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
}
defer f.Close()
policyBytes, err := io.ReadAll(f)
if err != nil {
log.Fatal().Err(err).Msg("Error reading the policy file")
return
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
}
request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
if _, err := client.SetPolicy(ctx, request); err != nil {
log.Fatal().Err(err).Msg("Failed to set ACL Policy")
return
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
}
SuccessOutput(nil, "Policy updated.", "")

View file

@ -60,11 +60,9 @@ var listPreAuthKeys = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -85,8 +83,6 @@ var listPreAuthKeys = &cobra.Command{
if output != "" {
SuccessOutput(response.GetPreAuthKeys(), "", output)
return
}
tableData := pterm.TableData{
@ -134,8 +130,6 @@ var listPreAuthKeys = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -150,20 +144,12 @@ var createPreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
reusable, _ := cmd.Flags().GetBool("reusable")
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
tags, _ := cmd.Flags().GetStringSlice("tags")
log.Trace().
Bool("reusable", reusable).
Bool("ephemeral", ephemeral).
Str("user", user).
Msg("Preparing to create preauthkey")
request := &v1.CreatePreAuthKeyRequest{
User: user,
Reusable: reusable,
@ -180,8 +166,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
@ -192,7 +176,7 @@ var createPreAuthKeyCmd = &cobra.Command{
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -203,8 +187,6 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
@ -227,11 +209,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -247,8 +227,6 @@ var expirePreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)

View file

@ -9,6 +9,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/tcnksm/go-latest"
)
@ -49,11 +50,6 @@ func initConfig() {
}
}
cfg, err := types.GetHeadscaleConfig()
if err != nil {
log.Fatal().Err(err).Msg("Failed to read headscale configuration")
}
machineOutput := HasMachineOutputFlag()
// If the user has requested a "node" readable format,
@ -62,11 +58,13 @@ func initConfig() {
zerolog.SetGlobalLevel(zerolog.Disabled)
}
if cfg.Log.Format == types.JSONLogFormat {
log.Logger = log.Output(os.Stdout)
}
// logFormat := viper.GetString("log.format")
// if logFormat == types.JSONLogFormat {
// log.Logger = log.Output(os.Stdout)
// }
if !cfg.DisableUpdateCheck && !machineOutput {
disableUpdateCheck := viper.GetBool("disable_check_updates")
if !disableUpdateCheck && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
Version != "dev" {
githubTag := &latest.GithubTag{

View file

@ -64,11 +64,9 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -82,14 +80,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetRoutes(), "", output)
return
}
routes = response.GetRoutes()
@ -103,14 +97,10 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetRoutes(), "", output)
return
}
routes = response.GetRoutes()
@ -119,8 +109,6 @@ var listRoutesCmd = &cobra.Command{
tableData := routesToPtables(routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
@ -130,8 +118,6 @@ var listRoutesCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -150,11 +136,9 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -167,14 +151,10 @@ var enableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}
@ -193,11 +173,9 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -210,14 +188,10 @@ var disableRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}
@ -236,11 +210,9 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -253,14 +225,10 @@ var deleteRouteCmd = &cobra.Command{
fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response, "", output)
return
}
},
}

View file

@ -16,7 +16,7 @@ var serveCmd = &cobra.Command{
return nil
},
Run: func(cmd *cobra.Command, args []string) {
app, err := getHeadscaleApp()
app, err := newHeadscaleServerWithConfig()
if err != nil {
log.Fatal().Caller().Err(err).Msg("Error initializing")
}

View file

@ -44,7 +44,7 @@ var createUserCmd = &cobra.Command{
userName := args[0]
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -63,8 +63,6 @@ var createUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response.GetUser(), "User created", output)
@ -91,7 +89,7 @@ var destroyUserCmd = &cobra.Command{
Name: userName,
}
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -102,8 +100,6 @@ var destroyUserCmd = &cobra.Command{
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
output,
)
return
}
confirm := false
@ -134,8 +130,6 @@ var destroyUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response, "User destroyed", output)
} else {
@ -151,7 +145,7 @@ var listUsersCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -164,14 +158,10 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetUsers(), "", output)
return
}
tableData := pterm.TableData{{"ID", "Name", "Created"}}
@ -192,8 +182,6 @@ var listUsersCmd = &cobra.Command{
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -213,7 +201,7 @@ var renameUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := getHeadscaleCLIClient()
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
@ -232,8 +220,6 @@ var renameUserCmd = &cobra.Command{
),
output,
)
return
}
SuccessOutput(response.GetUser(), "User renamed", output)

View file

@ -23,8 +23,8 @@ const (
SocketWritePermissions = 0o666
)
func getHeadscaleApp() (*hscontrol.Headscale, error) {
cfg, err := types.GetHeadscaleConfig()
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
cfg, err := types.LoadServerConfig()
if err != nil {
return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w",
@ -40,8 +40,8 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := types.GetHeadscaleConfig()
func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := types.LoadCLIConfig()
if err != nil {
log.Fatal().
Err(err).
@ -130,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
return ctx, client, conn, cancel
}
func SuccessOutput(result interface{}, override string, outputFormat string) {
func output(result interface{}, override string, outputFormat string) string {
var jsonBytes []byte
var err error
switch outputFormat {
@ -151,21 +151,26 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
}
default:
// nolint
fmt.Println(override)
return
return override
}
// nolint
fmt.Println(string(jsonBytes))
return string(jsonBytes)
}
// SuccessOutput prints the result to stdout and exits with status code 0.
func SuccessOutput(result interface{}, override string, outputFormat string) {
fmt.Println(output(result, override, outputFormat))
os.Exit(0)
}
// ErrorOutput prints an error message to stderr and exits with status code 1.
func ErrorOutput(errResult error, override string, outputFormat string) {
type errOutput struct {
Error string `json:"error"`
}
SuccessOutput(errOutput{errResult.Error()}, override, outputFormat)
fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errResult.Error()}, override, outputFormat))
os.Exit(1)
}
func HasMachineOutputFlag() bool {

View file

@ -4,7 +4,6 @@ import (
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
@ -113,60 +112,3 @@ func (*Suite) TestConfigLoading(c *check.C) {
c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false)
c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false)
}
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
}
func (*Suite) TestTLSConfigValidation(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***")
c.Assert(
tmp,
check.Matches,
".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
c.Assert(
tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
writeConfig(c, tmpDir, configYaml)
err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)
}

View file

@ -684,7 +684,7 @@ func (api headscaleV1APIServer) GetPolicy(
case types.PolicyModeDB:
p, err := api.h.db.GetPolicy()
if err != nil {
return nil, err
return nil, fmt.Errorf("loading ACL from database: %w", err)
}
return &v1.GetPolicyResponse{
@ -696,20 +696,20 @@ func (api headscaleV1APIServer) GetPolicy(
absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
f, err := os.Open(absPath)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
}
defer f.Close()
b, err := io.ReadAll(f)
if err != nil {
return nil, err
return nil, fmt.Errorf("reading policy from file: %w", err)
}
return &v1.GetPolicyResponse{Policy: string(b)}, nil
}
return nil, nil
return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
}
func (api headscaleV1APIServer) SetPolicy(

View file

@ -212,6 +212,12 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}
// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
// It has to be called before LoadServerConfig and LoadCLIConfig.
// The configuration is not validated and the caller should check for errors
// using a validation function.
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
@ -284,14 +290,14 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
}
return nil
}
func validateServerConfig() error {
depr := deprecator{
warns: make(set.Set[string]),
fatals: make(set.Set[string]),
@ -360,12 +366,12 @@ func LoadConfig(path string, isFile bool) error {
if errorText != "" {
// nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
}
func GetTLSConfig() TLSConfig {
return nil
}
func tlsConfig() TLSConfig {
return TLSConfig{
LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"),
@ -384,7 +390,7 @@ func GetTLSConfig() TLSConfig {
}
}
func GetDERPConfig() DERPConfig {
func derpConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
@ -445,7 +451,7 @@ func GetDERPConfig() DERPConfig {
}
}
func GetLogTailConfig() LogTailConfig {
func logtailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{
@ -453,7 +459,7 @@ func GetLogTailConfig() LogTailConfig {
}
}
func GetPolicyConfig() PolicyConfig {
func policyConfig() PolicyConfig {
policyPath := viper.GetString("policy.path")
policyMode := viper.GetString("policy.mode")
@ -463,7 +469,7 @@ func GetPolicyConfig() PolicyConfig {
}
}
func GetLogConfig() LogConfig {
func logConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
@ -473,9 +479,9 @@ func GetLogConfig() LogConfig {
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
case JSONLogFormat:
logFormat = JSONLogFormat
case "text":
case TextLogFormat:
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
@ -491,7 +497,7 @@ func GetLogConfig() LogConfig {
}
}
func GetDatabaseConfig() DatabaseConfig {
func databaseConfig() DatabaseConfig {
debug := viper.GetBool("database.debug")
type_ := viper.GetString("database.type")
@ -543,7 +549,7 @@ func GetDatabaseConfig() DatabaseConfig {
}
}
func DNS() (DNSConfig, error) {
func dns() (DNSConfig, error) {
var dns DNSConfig
// TODO: Use this instead of manually getting settings when
@ -575,12 +581,12 @@ func DNS() (DNSConfig, error) {
return dns, nil
}
// GlobalResolvers returns the global DNS resolvers
// globalResolvers returns the global DNS resolvers
// defined in the config file.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
var resolvers []*dnstype.Resolver
for _, nsStr := range d.Nameservers.Global {
@ -613,11 +619,11 @@ func (d *DNSConfig) GlobalResolvers() []*dnstype.Resolver {
return resolvers
}
// SplitResolvers returns a map of domain to DNS resolvers.
// splitResolvers returns a map of domain to DNS resolvers.
// If a nameserver is a valid IP, it will be used as a regular resolver.
// If a nameserver is a valid URL, it will be used as a DoH resolver.
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
routes := make(map[string][]*dnstype.Resolver)
for domain, nameservers := range d.Nameservers.Split {
var resolvers []*dnstype.Resolver
@ -653,7 +659,7 @@ func (d *DNSConfig) SplitResolvers() map[string][]*dnstype.Resolver {
return routes
}
func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg := tailcfg.DNSConfig{}
if dns.BaseDomain == "" && dns.MagicDNS {
@ -662,9 +668,9 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
cfg.Proxied = dns.MagicDNS
cfg.ExtraRecords = dns.ExtraRecords
cfg.Resolvers = dns.GlobalResolvers()
cfg.Resolvers = dns.globalResolvers()
routes := dns.SplitResolvers()
routes := dns.splitResolvers()
cfg.Routes = routes
if dns.BaseDomain != "" {
cfg.Domains = []string{dns.BaseDomain}
@ -674,7 +680,7 @@ func DNSToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
return &cfg
}
func PrefixV4() (*netip.Prefix, error) {
func prefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
if prefixV4Str == "" {
@ -698,7 +704,7 @@ func PrefixV4() (*netip.Prefix, error) {
return &prefixV4, nil
}
func PrefixV6() (*netip.Prefix, error) {
func prefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" {
@ -723,9 +729,12 @@ func PrefixV6() (*netip.Prefix, error) {
return &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
// LoadCLIConfig returns the needed configuration for the CLI client
// of Headscale to connect to a Headscale server.
func LoadCLIConfig() (*Config, error) {
return &Config{
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
UnixSocket: viper.GetString("unix_socket"),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
@ -735,15 +744,22 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
logConfig := GetLogConfig()
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
if err := validateServerConfig(); err != nil {
return nil, err
}
logConfig := logConfig()
zerolog.SetGlobalLevel(logConfig.Level)
prefix4, err := PrefixV4()
prefix4, err := prefixV4()
if err != nil {
return nil, err
}
prefix6, err := PrefixV6()
prefix6, err := prefixV6()
if err != nil {
return nil, err
}
@ -763,13 +779,13 @@ func GetHeadscaleConfig() (*Config, error) {
return nil, fmt.Errorf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
}
dnsConfig, err := DNS()
dnsConfig, err := dns()
if err != nil {
return nil, err
}
derpConfig := GetDERPConfig()
logTailConfig := GetLogTailConfig()
derpConfig := derpConfig()
logTailConfig := logtailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
oidcClientSecret := viper.GetString("oidc.client_secret")
@ -806,7 +822,7 @@ func GetHeadscaleConfig() (*Config, error) {
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
DisableUpdateCheck: false,
PrefixV4: prefix4,
PrefixV6: prefix6,
@ -823,11 +839,11 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout",
),
Database: GetDatabaseConfig(),
Database: databaseConfig(),
TLS: GetTLSConfig(),
TLS: tlsConfig(),
DNSConfig: DNSToTailcfgDNS(dnsConfig),
DNSConfig: dnsToTailcfgDNS(dnsConfig),
DNSUserNameInMagicDNS: dnsConfig.UserNameInMagicDNS,
ACMEEmail: viper.GetString("acme_email"),
@ -870,7 +886,7 @@ func GetHeadscaleConfig() (*Config, error) {
LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort,
Policy: GetPolicyConfig(),
Policy: policyConfig(),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
@ -890,10 +906,6 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]

View file

@ -1,6 +1,8 @@
package types
import (
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
@ -22,7 +24,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-config",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -48,12 +50,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: true,
@ -79,7 +81,7 @@ func TestReadConfig(t *testing.T) {
name: "unmarshal-dns-full-no-magic",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -105,12 +107,12 @@ func TestReadConfig(t *testing.T) {
name: "dns-to-tailcfg.DNSConfig",
configPath: "testdata/dns_full_no_magic.yaml",
setup: func(t *testing.T) (any, error) {
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
return DNSToTailcfgDNS(dns), nil
return dnsToTailcfgDNS(dns), nil
},
want: &tailcfg.DNSConfig{
Proxied: false,
@ -136,7 +138,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-in-server-url-err",
configPath: "testdata/base-domain-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
return GetHeadscaleConfig()
return LoadServerConfig()
},
want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
@ -145,7 +147,7 @@ func TestReadConfig(t *testing.T) {
name: "base-domain-not-in-server-url",
configPath: "testdata/base-domain-not-in-server-url.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -165,7 +167,7 @@ func TestReadConfig(t *testing.T) {
name: "policy-path-is-loaded",
configPath: "testdata/policy-path-is-loaded.yaml",
setup: func(t *testing.T) (any, error) {
cfg, err := GetHeadscaleConfig()
cfg, err := LoadServerConfig()
if err != nil {
return nil, err
}
@ -245,7 +247,7 @@ func TestReadConfigFromEnv(t *testing.T) {
setup: func(t *testing.T) (any, error) {
t.Logf("all settings: %#v", viper.AllSettings())
dns, err := DNS()
dns, err := dns()
if err != nil {
return nil, err
}
@ -289,3 +291,49 @@ func TestReadConfigFromEnv(t *testing.T) {
})
}
}
func TestTLSConfigValidation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
t.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
// Check configuration validation errors (1)
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
err = validateServerConfig()
assert.Error(t, err)
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"
"testing"
"time"
@ -735,13 +736,7 @@ func TestNodeTagCommand(t *testing.T) {
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
// try to set a wrong tag and retrieve the error
type errOutput struct {
Error string `json:"error"`
}
var errorOutput errOutput
err = executeAndUnmarshal(
headscale,
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
@ -750,10 +745,8 @@ func TestNodeTagCommand(t *testing.T) {
"-t", "wrong-tag",
"--output", "json",
},
&errorOutput,
)
assert.Nil(t, err)
assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'")
assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, len(machineKeys))
@ -1398,18 +1391,17 @@ func TestNodeRenameCommand(t *testing.T) {
assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5")
// Test failure for too long names
result, err := headscale.Execute(
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[4].GetId()),
"testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890",
strings.Repeat("t", 64),
},
)
assert.Nil(t, err)
assert.Contains(t, result, "not be over 63 chars")
assert.ErrorContains(t, err, "not be over 63 chars")
var listAllAfterRenameAttempt []v1.Node
err = executeAndUnmarshal(
@ -1536,7 +1528,7 @@ func TestNodeMoveCommand(t *testing.T) {
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
moveToNonExistingNSResult, err := headscale.Execute(
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
@ -1549,11 +1541,9 @@ func TestNodeMoveCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.Contains(
assert.ErrorContains(
t,
moveToNonExistingNSResult,
err,
"user not found",
)
assert.Equal(t, node.GetUser().GetName(), "new-user")