Go format with shorter lines

This commit is contained in:
Kristoffer Dalby 2021-11-13 08:36:45 +00:00
parent edfcdc466c
commit 03b7ec62ca
35 changed files with 794 additions and 192 deletions

View file

@ -108,7 +108,9 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u)
}
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
func (h *Headscale) generateACLPolicyDestPorts(
d string,
) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
if len(tokens) < 2 || len(tokens) > 3 {
return nil, errorInvalidPortFormat

View file

@ -22,7 +22,11 @@ func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
func (s *Suite) TestParseHosts(c *check.C) {
var hs Hosts
err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
err := hs.UnmarshalJSON(
[]byte(
`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
),
)
c.Assert(hs, check.NotNil)
c.Assert(err, check.IsNil)
}

55
api.go
View file

@ -95,7 +95,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).
Inc()
return
}
m = &newMachine
@ -156,11 +157,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).
Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
return
}
@ -195,11 +198,13 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "")
return
}
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).
Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
return
}
@ -234,7 +239,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
mKey.HexString(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
@ -257,7 +266,11 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.Data(200, "application/json; charset=utf-8", respBody)
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
func (h *Headscale) getMapResponse(
mKey wgkey.Key,
req tailcfg.MapRequest,
m *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
@ -291,7 +304,12 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
return nil, err
}
dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers)
dnsConfig, err := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*m,
peers,
)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@ -340,7 +358,11 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
return data, nil
}
func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
func (h *Headscale) getMapKeepAliveResponse(
mKey wgkey.Key,
req tailcfg.MapRequest,
m *Machine,
) ([]byte, error) {
resp := tailcfg.MapResponse{
KeepAlive: true,
}
@ -394,7 +416,8 @@ func (h *Headscale) handleAuthKey(
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc()
return
}
c.Data(401, "application/json; charset=utf-8", respBody)
@ -402,7 +425,8 @@ func (h *Headscale) handleAuthKey(
Str("func", "handleAuthKey").
Str("machine", m.Name).
Msg("Failed authentication via AuthKey")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc()
return
}
@ -416,7 +440,8 @@ func (h *Headscale) handleAuthKey(
Str("func", "handleAuthKey").
Str("machine", m.Name).
Msg("Failed to find an available IP")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc()
return
}
log.Info().
@ -445,11 +470,13 @@ func (h *Headscale) handleAuthKey(
Str("machine", m.Name).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).
Inc()
c.String(http.StatusInternalServerError, "Extremely sad!")
return
}
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc()
machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).
Inc()
c.Data(200, "application/json; charset=utf-8", respBody)
log.Info().
Str("func", "handleAuthKey").

75
app.go
View file

@ -152,8 +152,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
var dbString string
switch cfg.DBtype {
case "postgres":
dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
dbString = fmt.Sprintf(
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
cfg.DBhost,
cfg.DBport,
cfg.DBname,
cfg.DBuser,
cfg.DBpass,
)
case "sqlite3":
dbString = cfg.DBpath
default:
@ -182,7 +188,10 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
}
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
magicDNSDomains, err := generateMagicDNSRootDomains(
h.cfg.IPPrefix,
h.cfg.BaseDomain,
)
if err != nil {
return nil, err
}
@ -224,7 +233,10 @@ func (h *Headscale) expireEphemeralNodesWorker() {
for _, ns := range namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name)
if err != nil {
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
log.Error().
Err(err).
Str("namespace", ns.Name).
Msg("Error listing machines in namespace")
return
}
@ -232,7 +244,9 @@ func (h *Headscale) expireEphemeralNodesWorker() {
for _, m := range machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
log.Info().
Str("machine", m.Name).
Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error
if err != nil {
@ -274,18 +288,33 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// the server
p, _ := peer.FromContext(ctx)
log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate")
log.Trace().
Caller().
Str("client_address", p.Addr.String()).
Msg("Client is trying to authenticate")
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed")
return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Msg("Retrieving metadata is failed")
return ctx, status.Errorf(
codes.InvalidArgument,
"Retrieving metadata is failed",
)
}
authHeader, ok := md["authorization"]
if !ok {
log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied")
return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied")
log.Error().
Caller().
Str("client_address", p.Addr.String()).
Msg("Authorization token is not supplied")
return ctx, status.Errorf(
codes.Unauthenticated,
"Authorization token is not supplied",
)
}
token := authHeader[0]
@ -295,7 +324,10 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
Caller().
Str("client_address", p.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
return ctx, status.Error(
codes.Unauthenticated,
`missing "Bearer " prefix in "Authorization" header`,
)
}
// TODO(kradalby): Implement API key backend:
@ -307,7 +339,10 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
// and API key auth
return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet")
return ctx, status.Error(
codes.Unauthenticated,
"Authentication is not implemented yet",
)
//if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
// log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
@ -405,7 +440,10 @@ func (h *Headscale) Serve() error {
// Match gRPC requests here
grpcListener := m.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"),
cmux.HTTP2MatchHeaderFieldSendSettings(
"content-type",
"application/grpc+proto",
),
)
// Otherwise match regular http requests.
httpListener := m.Match(cmux.Any())
@ -436,7 +474,10 @@ func (h *Headscale) Serve() error {
p := ginprometheus.NewPrometheus("gin")
p.Use(r)
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) })
r.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
r.GET("/key", h.KeyHandler)
r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler)
@ -537,7 +578,8 @@ func (h *Headscale) Serve() error {
g.Go(func() error { return m.Serve() })
log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
return g.Wait()
}
@ -545,7 +587,8 @@ func (h *Headscale) Serve() error {
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
log.Warn().
Msg("Listening with TLS but ServerURL does not start with https://")
}
m := autocert.Manager{

View file

@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
type Suite struct{}
var tmpDir string
var h Headscale
var (
tmpDir string
h Headscale
)
func (s *Suite) SetUpTest(c *check.C) {
s.ResetDB(c)

View file

@ -73,7 +73,11 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple index template"),
)
return
}
@ -89,7 +93,11 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return
}
@ -99,7 +107,11 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Failed to create UUID"),
)
return
}
@ -117,7 +129,11 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple macOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple macOS template"),
)
return
}
case "ios":
@ -126,11 +142,19 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple iOS template"),
)
return
}
default:
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"))
c.Data(
http.StatusOK,
"text/html; charset=utf-8",
[]byte("Invalid platform, only ios and macos is supported"),
)
return
}
@ -146,11 +170,19 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Apple platform template"),
)
return
}
c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes())
c.Data(
http.StatusOK,
"application/x-apple-aspen-config; charset=utf-8",
content.Bytes(),
)
}
type AppleMobileConfig struct {
@ -164,7 +196,8 @@ type AppleMobilePlatformConfig struct {
Url string
}
var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
var commonTemplate = template.Must(
template.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
@ -187,7 +220,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`<?xml ver
{{.Payload}}
</array>
</dict>
</plist>`))
</plist>`),
)
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
<dict>

View file

@ -28,7 +28,10 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name)
m2, err := h.RegisterMachine(
"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
n.Name,
)
c.Assert(err, check.IsNil)
c.Assert(m2.Registered, check.Equals, true)

View file

@ -27,7 +27,8 @@ func init() {
if err != nil {
log.Fatal().Err(err).Msg("")
}
createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
createNodeCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
debugCmd.AddCommand(createNodeCmd)
}
@ -56,19 +57,31 @@ var createNodeCmd = &cobra.Command{
name, err := cmd.Flags().GetString("name")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return
}
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
@ -81,7 +94,11 @@ var createNodeCmd = &cobra.Command{
response, err := client.DebugCreateMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
output,
)
return
}

View file

@ -48,7 +48,14 @@ var createNamespaceCmd = &cobra.Command{
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
response, err := client.CreateNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot create namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}
@ -78,7 +85,14 @@ var destroyNamespaceCmd = &cobra.Command{
response, err := client.DeleteNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot destroy namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}
@ -100,7 +114,11 @@ var listNamespacesCmd = &cobra.Command{
response, err := client.ListNamespaces(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
output,
)
return
}
@ -122,7 +140,11 @@ var listNamespacesCmd = &cobra.Command{
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -151,7 +173,14 @@ var renameNamespaceCmd = &cobra.Command{
response, err := client.RenameNamespace(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename namespace: %s",
status.Convert(err).Message(),
),
output,
)
return
}

View file

@ -86,7 +86,11 @@ var registerNodeCmd = &cobra.Command{
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine key from flag: %s", err),
output,
)
return
}
@ -97,7 +101,14 @@ var registerNodeCmd = &cobra.Command{
response, err := client.RegisterMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
@ -126,7 +137,11 @@ var listNodesCmd = &cobra.Command{
response, err := client.ListMachines(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
@ -143,7 +158,11 @@ var listNodesCmd = &cobra.Command{
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -157,7 +176,11 @@ var deleteNodeCmd = &cobra.Command{
id, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
@ -171,7 +194,14 @@ var deleteNodeCmd = &cobra.Command{
getResponse, err := client.GetMachine(ctx, getRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node node: %s",
status.Convert(err).Message(),
),
output,
)
return
}
@ -183,7 +213,10 @@ var deleteNodeCmd = &cobra.Command{
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name),
Message: fmt.Sprintf(
"Do you want to remove the node %s?",
getResponse.GetMachine().Name,
),
}
err = survey.AskOne(prompt, &confirm)
if err != nil {
@ -198,10 +231,21 @@ var deleteNodeCmd = &cobra.Command{
return
}
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Error deleting node: %s",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output)
SuccessOutput(
map[string]string{"Result": "Node deleted"},
"Node deleted",
output,
)
} else {
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
}
@ -235,7 +279,11 @@ func sharingWorker(
machineResponse, err := client.GetMachine(ctx, machineRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err
}
@ -245,7 +293,11 @@ func sharingWorker(
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
output,
)
return "", nil, nil, err
}
@ -258,7 +310,11 @@ var shareMachineCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return
}
@ -273,7 +329,11 @@ var shareMachineCmd = &cobra.Command{
response, err := client.ShareMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
output,
)
return
}
@ -287,7 +347,11 @@ var unshareMachineCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, machine, namespace, err := sharingWorker(cmd, args)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
output,
)
return
}
@ -302,7 +366,11 @@ var unshareMachineCmd = &cobra.Command{
response, err := client.UnshareMachine(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
output,
)
return
}
@ -310,8 +378,22 @@ var unshareMachineCmd = &cobra.Command{
},
}
func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) {
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
func nodesToPtables(
currentNamespace string,
machines []*v1.Machine,
) (pterm.TableData, error) {
d := pterm.TableData{
{
"ID",
"Name",
"NodeKey",
"Namespace",
"IP address",
"Ephemeral",
"Last seen",
"Online",
},
}
for _, machine := range machines {
var ephemeral bool
@ -331,7 +413,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
nodeKey := tailcfg.NodeKey(nKey)
var online string
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
if lastSeen.After(
time.Now().Add(-5 * time.Minute),
) { // TODO: Find a better way to reliably show if online
online = pterm.LightGreen("true")
} else {
online = pterm.LightRed("false")

View file

@ -22,8 +22,10 @@ func init() {
preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.PersistentFlags().
Bool("reusable", false, "Make the preauthkey reusable")
createPreAuthKeyCmd.PersistentFlags().
Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags().
DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
}
@ -55,7 +57,11 @@ var listPreAuthKeys = &cobra.Command{
response, err := client.ListPreAuthKeys(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return
}
@ -64,7 +70,9 @@ var listPreAuthKeys = &cobra.Command{
return
}
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
d := pterm.TableData{
{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
}
for _, k := range response.PreAuthKeys {
expiration := "-"
if k.GetExpiration() != nil {
@ -91,7 +99,11 @@ var listPreAuthKeys = &cobra.Command{
}
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -139,7 +151,11 @@ var createPreAuthKeyCmd = &cobra.Command{
response, err := client.CreatePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return
}
@ -175,7 +191,11 @@ var expirePreAuthKeyCmd = &cobra.Command{
response, err := client.ExpirePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return
}

View file

@ -10,7 +10,8 @@ import (
func init() {
rootCmd.PersistentFlags().
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
rootCmd.PersistentFlags().
Bool("force", false, "Disable prompts and forces the execution")
}
var rootCmd = &cobra.Command{

View file

@ -21,7 +21,8 @@ func init() {
}
routesCmd.AddCommand(listRoutesCmd)
enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = enableRouteCmd.MarkFlagRequired("identifier")
if err != nil {
@ -46,7 +47,11 @@ var listRoutesCmd = &cobra.Command{
machineId, err := cmd.Flags().GetUint64("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
@ -60,7 +65,11 @@ var listRoutesCmd = &cobra.Command{
response, err := client.GetMachineRoute(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
output,
)
return
}
@ -77,7 +86,11 @@ var listRoutesCmd = &cobra.Command{
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
@ -95,13 +108,21 @@ omit the route you do not want to enable.
output, _ := cmd.Flags().GetString("output")
machineId, err := cmd.Flags().GetUint64("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting machine id from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
@ -116,7 +137,14 @@ omit the route you do not want to enable.
response, err := client.EnableMachineRoutes(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register machine: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
@ -133,7 +161,11 @@ omit the route you do not want to enable.
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},

View file

@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers))
restrictedResolvers := make(
[]dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
@ -219,7 +224,9 @@ func getHeadscaleConfig() headscale.Config {
"10h",
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
maxMachineRegistrationDuration = viper.GetDuration(
"max_machine_registration_duration",
)
}
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
@ -229,7 +236,9 @@ func getHeadscaleConfig() headscale.Config {
"8h",
) // use 8h here because it's the length of a standard business day
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
defaultMachineRegistrationDuration = viper.GetDuration(
"default_machine_registration_duration",
)
}
dnsConfig, baseDomain := GetDNSConfig()
@ -244,7 +253,9 @@ func getHeadscaleConfig() headscale.Config {
DERP: derpConfig,
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")),
@ -254,9 +265,11 @@ func getHeadscaleConfig() headscale.Config {
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")),
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
@ -431,7 +444,10 @@ type tokenAuth struct {
}
// Return value is mapped to request headers.
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
func (t tokenAuth) GetRequestMetadata(
ctx context.Context,
in ...string,
) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + t.token,
}, nil

View file

@ -63,7 +63,8 @@ func main() {
}
if !viper.GetBool("disable_check_updates") && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
cli.Version != "dev" {
githubTag := &latest.GithubTag{
Owner: "juanfont",
Repository: "headscale",

View file

@ -40,7 +40,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil {
c.Fatal(err)
}
@ -74,7 +77,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
err = os.Symlink(
filepath.Clean(path+"/../../config-example.yaml"),
filepath.Join(tmpDir, "config.yaml"),
)
if err != nil {
c.Fatal(err)
}
@ -128,7 +134,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*")
c.Assert(
tmp,
check.Matches,
".*Fatal config error: server_url must start with https:// or http://.*",
)
fmt.Println(tmp)
// Check configuration validation errors (2)

5
db.go
View file

@ -87,7 +87,10 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
// getValue returns the value for the given key in KV
func (h *Headscale) getValue(key string) (string, error) {
var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&row, "key = ?", key); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return "", errors.New("not found")
}
return row.Value, nil

17
dns.go
View file

@ -30,7 +30,10 @@ import (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
func generateMagicDNSRootDomains(
ipPrefix netaddr.IPPrefix,
baseDomain string,
) ([]dnsname.FQDN, error) {
// TODO(juanfont): we are not handing out IPv6 addresses yet
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
@ -69,12 +72,20 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
return fqdns, nil
}
func getMapResponseDNSConfig(dnsConfigOrig *tailcfg.DNSConfig, baseDomain string, m Machine, peers Machines) (*tailcfg.DNSConfig, error) {
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
m Machine,
peers Machines,
) (*tailcfg.DNSConfig, error) {
var dnsConfig *tailcfg.DNSConfig
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
dnsConfig = dnsConfigOrig.Clone()
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain))
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf("%s.%s", m.Namespace.Name, baseDomain),
)
namespaceSet := set.New(set.ThreadSafe)
namespaceSet.Add(m.Namespace)

View file

@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
ctx context.Context,
request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) {
log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine")
log.Trace().
Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()).
Msg("Registering machine")
machine, err := api.h.RegisterMachine(
request.GetKey(),
request.GetNamespace(),
@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
return nil, err
}
sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace())
sharedMachines, err := api.h.ListSharedMachinesInNamespace(
request.GetNamespace(),
)
if err != nil {
return nil, err
}
@ -338,7 +343,11 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err
}
log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("")
log.Trace().
Caller().
Interface("route-prefix", routes).
Interface("route-str", request.GetRoutes()).
Msg("")
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,

View file

@ -109,7 +109,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
}
}
func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
func (s *IntegrationCLITestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats
}
@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
assert.True(
s.T(),
listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
assert.True(
s.T(),
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
// Expire three keys
for i := 0; i < 3; i++ {
@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
assert.Nil(s.T(), err)
assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
)
assert.True(
s.T(),
listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
)
}
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlySharedMachineNamespace []v1.Machine
err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace)
err = json.Unmarshal(
[]byte(listOnlySharedMachineNamespaceResult),
&listOnlySharedMachineNamespace,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterDelete []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterDeleteResult),
&listOnlyMachineNamespaceAfterDelete,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterShare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterShareResult),
&listOnlyMachineNamespaceAfterShare,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare)
err = json.Unmarshal(
[]byte(listOnlyMachineNamespaceAfterUnshareResult),
&listOnlyMachineNamespaceAfterUnshare,
)
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
)
assert.Nil(s.T(), err)
assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node")
assert.Contains(
s.T(),
string(failEnableNonAdvertisedRoute),
"route (route-machine) is not available on node",
)
}

View file

@ -12,7 +12,11 @@ import (
"github.com/ory/dockertest/v3/docker"
)
func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) {
func ExecuteCommand(
resource *dockertest.Resource,
cmd []string,
env []string,
) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer

View file

@ -89,7 +89,10 @@ func TestIntegrationTestSuite(t *testing.T) {
}
}
func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error {
func (s *IntegrationTestSuite) saveLog(
resource *dockertest.Resource,
basePath string,
) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
@ -118,12 +121,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644)
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644)
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
[]byte(stdout.String()),
0o644,
)
if err != nil {
return err
}
@ -144,14 +155,27 @@ func (s *IntegrationTestSuite) tailscaleContainer(
},
},
}
hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier)
hostname := fmt.Sprintf(
"%s-tailscale-%s-%s",
namespace,
strings.Replace(version, ".", "-", -1),
identifier,
)
tailscaleOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{&s.network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Cmd: []string{
"tailscaled",
"--tun=userspace-networking",
"--socks5-server=localhost:1055",
},
}
pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy)
pts, err := s.pool.BuildAndRunWithBuildOptions(
tailscaleBuildOptions,
tailscaleOptions,
DockerRestartPolicy,
)
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
for i := 0; i < scales.count; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)]
hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version)
hostname, container := s.tailscaleContainer(
namespace,
fmt.Sprint(i),
version,
)
scales.tailscales[hostname] = *container
}
}
@ -273,7 +301,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
headscaleEndpoint := "http://headscale:8080"
fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint)
fmt.Printf(
"Joining tailscale containers to headscale at %s\n",
headscaleEndpoint,
)
for hostname, tailscale := range scales.tailscales {
command := []string{
"tailscale",
@ -307,7 +338,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
func (s *IntegrationTestSuite) TearDownSuite() {
}
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
func (s *IntegrationTestSuite) HandleStats(
suiteName string,
stats *suite.SuiteInformation,
) {
s.stats = stats
}
@ -427,7 +461,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -449,7 +489,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
result, err := ExecuteCommand(
&s.headscale,
[]string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"},
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
"--namespace",
"shared",
},
[]string{},
)
assert.Nil(s.T(), err)
@ -520,7 +568,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
fmt.Printf(
"Pinging from %s (%s) to %s (%s)\n",
hostname,
mainIps[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -578,9 +632,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"PUT",
"--upload-file",
fmt.Sprintf("/tmp/file_from_%s", hostname),
fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname),
fmt.Sprintf(
"%s/v0/put/file_from_%s",
peerAPI,
hostname,
),
}
fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Sending file from %s (%s) to %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
_, err = ExecuteCommand(
&tailscale,
command,
@ -621,7 +685,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"ls",
fmt.Sprintf("/tmp/file_from_%s", peername),
}
fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip)
fmt.Printf(
"Checking file in %s (%s) from %s (%s)\n",
hostname,
ips[hostname],
peername,
ip,
)
result, err := ExecuteCommand(
&tailscale,
command,
@ -629,7 +699,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result)
assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername))
assert.Equal(
t,
result,
fmt.Sprintf("/tmp/file_from_%s\n", peername),
)
}
})
}
@ -699,7 +773,9 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
return ips, nil
}
func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) {
func getAPIURLs(
tailscales map[string]dockertest.Resource,
) (map[netaddr.IP]string, error) {
fts := make(map[netaddr.IP]string)
for _, tailscale := range tailscales {
command := []string{

View file

@ -73,8 +73,12 @@ func (m Machine) isExpired() bool {
func (h *Headscale) updateMachineExpiry(m *Machine) {
if m.isExpired() {
now := time.Now().UTC()
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
maxExpiry := now.Add(
h.cfg.MaxMachineRegistrationDuration,
) // calculate the maximum expiry
defaultExpiry := now.Add(
h.cfg.DefaultMachineRegistrationDuration,
) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) {
@ -157,7 +161,9 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines {
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
namespaceMachines, err := h.ListMachinesInNamespace(
sharedMachine.Namespace.Name,
)
if err != nil {
return Machines{}, err
}
@ -392,7 +398,11 @@ func (ms Machines) toNodes(
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS
func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
func (m Machine) toNode(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey)
if err != nil {
return nil, err
@ -425,7 +435,10 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
addrs = append(addrs, ip) // missing the ipv6 ?
allowedIPs := []netaddr.IPPrefix{}
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
allowedIPs = append(
allowedIPs,
ip,
) // we append the node own IP, as it is required by the clients
if includeRoutes {
routesStr := []string{}
@ -571,7 +584,10 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
}
m := Machine{}
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errors.New("Machine not found")
}
@ -693,7 +709,11 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
for _, newRoute := range newRoutes {
if !containsIpPrefix(availableRoutes, newRoute) {
return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
return fmt.Errorf(
"route (%s) is not available on node %s",
m.Name,
newRoute,
)
}
}

View file

@ -32,7 +32,7 @@ var (
Name: "update_request_sent_to_node_total",
Help: "The number of calls/messages issued on a specific nodes update channel",
}, []string{"namespace", "machine", "status"})
//TODO(kradalby): This is very debugging, we might want to remove it.
// TODO(kradalby): This is very debugging, we might want to remove it.
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_received_on_channel_total",

View file

@ -102,7 +102,10 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
// GetNamespace fetches a namespace by name
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&n, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errorNamespaceNotFound
}
return &n, nil
@ -144,7 +147,9 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
machines := []Machine{}
for _, sharedMachine := range sharedMachines {
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
machine, err := h.GetMachineByID(
sharedMachine.MachineID,
) // otherwise not everything comes filled
if err != nil {
return nil, err
}
@ -173,7 +178,10 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
v, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil {
return err
}
@ -182,7 +190,10 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
names := []string{}
err = json.Unmarshal([]byte(v), &names)
if err != nil {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
err = h.setValue(
"namespaces_pending_updates",
fmt.Sprintf(`["%s"]`, namespace.Name),
)
if err != nil {
return err
}

38
oidc.go
View file

@ -39,8 +39,11 @@ func (h *Headscale) initOIDC() error {
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
@ -127,7 +130,10 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
c.String(
http.StatusBadRequest,
fmt.Sprintf("Failed to decode id token claims: %s", err),
)
return
}
@ -135,7 +141,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
log.Error().Msg("requested machine state key expired before authorisation completed")
log.Error().
Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired")
return
}
@ -151,7 +158,10 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
m, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database")
c.String(
http.StatusInternalServerError,
"could not get machine info from database",
)
return
}
@ -168,15 +178,22 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
ns, err = h.CreateNamespace(nsName)
if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
log.Error().
Msgf("could not create new namespace '%s'", claims.Email)
c.String(
http.StatusInternalServerError,
"could not create new namespace",
)
return
}
}
ip, err := h.getAvailableIP()
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
c.String(
http.StatusInternalServerError,
"could not get an IP from the pool",
)
return
}
@ -209,7 +226,10 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
Str("username", claims.Username).
Str("machine", m.Name).
Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
c.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)
}
// getNamespaceFromEmail passes the users email through a list of "matchers"

View file

@ -164,10 +164,18 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
}
got, got1 := h.getNamespaceFromEmail(tt.args.email)
if got != tt.want {
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
t.Errorf(
"Headscale.getNamespaceFromEmail() got = %v, want %v",
got,
tt.want,
)
}
if got1 != tt.want1 {
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
t.Errorf(
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
got1,
tt.want1,
)
}
})
}

37
poll.go
View file

@ -158,7 +158,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc()
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").
Inc()
go func() { updateChan <- struct{}{} }()
return
} else if req.OmitPeers && req.Stream {
@ -184,10 +185,20 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc()
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").
Inc()
go func() { updateChan <- struct{}{} }()
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
h.PollNetMapStream(
c,
m,
req,
mKey,
pollDataChan,
keepAliveChan,
updateChan,
cancelKeepAlive,
)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
@ -260,7 +271,8 @@ func (h *Headscale) PollNetMapStream(
now := time.Now().UTC()
m.LastSeen = &now
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).
Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
@ -324,7 +336,8 @@ func (h *Headscale) PollNetMapStream(
Str("machine", m.Name).
Str("channel", "update").
Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc()
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).
Inc()
if h.isOutdated(m) {
log.Debug().
Str("handler", "PollNetMapStream").
@ -349,7 +362,8 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc()
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").
Inc()
return false
}
log.Trace().
@ -357,7 +371,8 @@ func (h *Headscale) PollNetMapStream(
Str("machine", m.Name).
Str("channel", "update").
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc()
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").
Inc()
// Keep track of the last successful update,
// we sometimes end in a state were the update
@ -377,7 +392,8 @@ func (h *Headscale) PollNetMapStream(
}
now := time.Now().UTC()
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).
Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
@ -424,7 +440,7 @@ func (h *Headscale) PollNetMapStream(
Str("machine", m.Name).
Str("channel", "Done").
Msg("Closing update channel")
//h.closeUpdateChannel(m)
// h.closeUpdateChannel(m)
close(updateChan)
log.Trace().
@ -483,7 +499,8 @@ func (h *Headscale) scheduledPollWorker(
Str("func", "scheduledPollWorker").
Str("machine", m.Name).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc()
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").
Inc()
updateChan <- struct{}{}
}
}

View file

@ -105,7 +105,10 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
// If returns no error and a PreAuthKey, it can be used
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{}
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, errorAuthKeyNotFound
}

View file

@ -11,7 +11,10 @@ import (
// Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
func (h *Headscale) GetAdvertisedNodeRoutes(
namespace string,
nodeName string,
) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
@ -27,7 +30,10 @@ func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (
// Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
func (h *Headscale) GetEnabledNodeRoutes(
namespace string,
nodeName string,
) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
@ -58,7 +64,11 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
// Deprecated: use machine function instead
// IsNodeRouteEnabled checks if a certain route has been enabled
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
func (h *Headscale) IsNodeRouteEnabled(
namespace string,
nodeName string,
routeStr string,
) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
@ -80,7 +90,11 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
// Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
func (h *Headscale) EnableNodeRoute(
namespace string,
nodeName string,
routeStr string,
) error {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return err

View file

@ -93,7 +93,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
}
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
availableRoutes, err := h.GetAdvertisedNodeRoutes(
"test",
"test_enable_route_machine",
)
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)

View file

@ -2,9 +2,11 @@ package headscale
import "gorm.io/gorm"
const errorSameNamespace = Error("Destination namespace same as origin")
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
const errorMachineNotShared = Error("Machine not shared to this namespace")
const (
errorSameNamespace = Error("Destination namespace same as origin")
errorMachineAlreadyShared = Error("Node already shared to this namespace")
errorMachineNotShared = Error("Machine not shared to this namespace")
)
// SharedMachine is a join table to support sharing nodes between namespaces
type SharedMachine struct {
@ -48,7 +50,9 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
}
sharedMachine := SharedMachine{}
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine)
result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).
Unscoped().
Delete(&sharedMachine)
if result.Error != nil {
return result.Error
}

View file

@ -4,7 +4,10 @@ import (
"gopkg.in/check.v1"
)
func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) {
func CreateNodeNamespace(
c *check.C,
namespace, node, key, IP string,
) (*Namespace, *Machine) {
n1, err := h.CreateNamespace(namespace)
c.Assert(err, check.IsNil)
@ -229,7 +232,11 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(
len(p1sAfter),
check.Equals,
2,
) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)

View file

@ -53,7 +53,11 @@ func SwaggerUI(c *gin.Context) {
Caller().
Err(err).
Msg("Could not render Swagger")
c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"))
c.Data(
http.StatusInternalServerError,
"text/html; charset=utf-8",
[]byte("Could not render Swagger"),
)
return
}

View file

@ -25,11 +25,21 @@ type Error string
func (e Error) Error() string { return string(e) }
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
func decode(
msg []byte,
v interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
return decodeMsg(msg, v, pubKey, privKey)
}
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
func decodeMsg(
msg []byte,
v interface{},
pubKey *wgkey.Key,
privKey *wgkey.Private,
) error {
decrypted, err := decryptMsg(msg, pubKey, privKey)
if err != nil {
return err
@ -156,7 +166,11 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
}
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {