diff --git a/hscontrol/auth.go b/hscontrol/auth.go index aaab03ce..8b8557ba 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -66,7 +66,7 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { - logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey) + logInfo, logTrace, _ := logAuthFunc(regReq, machineKey) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) @@ -105,16 +105,6 @@ func (h *Headscale) handleRegister( logInfo("Node not found in database, creating new") - givenName, err := h.db.GenerateGivenName( - machineKey, - regReq.Hostinfo.Hostname, - ) - if err != nil { - logErr(err, "Failed to generate given name for node") - - return - } - // The node did not have a key to authenticate, which means // that we rely on a method that calls back some how (OpenID or CLI) // We create the node and then keep it around until a callback @@ -122,7 +112,6 @@ func (h *Headscale) handleRegister( newNode := types.Node{ MachineKey: machineKey, Hostname: regReq.Hostinfo.Hostname, - GivenName: givenName, NodeKey: regReq.NodeKey, LastSeen: &now, Expiry: &time.Time{}, @@ -354,21 +343,8 @@ func (h *Headscale) handleAuthKey( } else { now := time.Now().UTC() - givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname) - if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed to generate given name for node") - - return - } - nodeToRegister := types.Node{ Hostname: registerRequest.Hostinfo.Hostname, - GivenName: givenName, UserID: pak.User.ID, User: pak.User, MachineKey: machineKey, diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index a9e78a45..c0f42de1 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } -func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return getNode(rx, user, name) @@ -242,9 +228,9 @@ func SetTags( } // RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. +// and renames it. If the name is not unique, it will return an error. func RenameNode(tx *gorm.DB, - nodeID uint64, newName string, + nodeID types.NodeID, newName string, ) error { err := util.CheckForFQDNRules( newName, @@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB, return fmt.Errorf("renaming node: %w", err) } + uniq, err := isUnqiueName(tx, newName) + if err != nil { + return fmt.Errorf("checking if name is unique: %w", err) + } + + if !uniq { + return fmt.Errorf("name is not unique: %s", newName) + } + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } @@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad node.IPv4 = ipv4 node.IPv6 = ipv6 + if node.GivenName == "" { + givenName, err := ensureUniqueGivenName(tx, node.Hostname) + if err != nil { + return nil, fmt.Errorf("failed to ensure unique given name: %w", err) + } + + node.GivenName = givenName + } + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } @@ -642,40 +646,32 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { return normalizedHostname, nil } -func (hsdb *HSDatabase) GenerateGivenName( - mkey key.MachinePublic, - suppliedName string, -) (string, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { - return GenerateGivenName(rx, mkey, suppliedName) - }) +func isUnqiueName(tx *gorm.DB, name string) (bool, error) { + nodes := types.Nodes{} + if err := tx. + Where("given_name = ?", name).Find(&nodes).Error; err != nil { + return false, err + } + + return len(nodes) == 0, nil } -func GenerateGivenName( +func ensureUniqueGivenName( tx *gorm.DB, - mkey key.MachinePublic, - suppliedName string, + name string, ) (string, error) { - givenName, err := generateGivenName(suppliedName, false) + givenName, err := generateGivenName(name, false) if err != nil { return "", err } - // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := listNodesByGivenName(tx, givenName) + unique, err := isUnqiueName(tx, givenName) if err != nil { return "", err } - var nodeFound *types.Node - for idx, node := range nodes { - if node.GivenName == givenName { - nodeFound = nodes[idx] - } - } - - if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() { - postfixedName, err := generateGivenName(suppliedName, true) + if !unique { + postfixedName, err := generateGivenName(name, true) if err != nil { return "", err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 94cce13b..bafb22ba 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -19,6 +19,7 @@ import ( "github.com/puzpuzpuz/xsync/v3" "github.com/stretchr/testify/assert" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/ptr" @@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) { c.Assert(nodeFromDB.IsExpired(), check.Equals, true) } -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("user-1", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machineKey2 := key.NewMachine() - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) - - givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") - comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1") - comment = check.Commentf("Same user, same node, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1") - comment = check.Commentf("Same user, unique nodes, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment) -} - func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) { assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) } + +func TestRenameNode(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser("test") + assert.NoError(t, err) + + user2, err := db.CreateUser("test2") + assert.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + err = db.DB.Save(&node).Error + assert.NoError(t, err) + + err = db.DB.Save(&node2).Error + assert.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNode(tx, node, nil, nil) + if err != nil { + return err + } + _, err = RegisterNode(tx, node2, nil, nil) + return err + }) + assert.NoError(t, err) + + nodes, err := db.ListNodes() + assert.NoError(t, err) + + assert.Len(t, nodes, 2) + + t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName) + t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName) + + assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName) + assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName) + assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname) + assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName) + assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname) + assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname) + assert.Len(t, nodes[0].Hostname, 4) + assert.Len(t, nodes[1].Hostname, 4) + assert.Len(t, nodes[0].GivenName, 4) + assert.Len(t, nodes[1].GivenName, 13) + + // Nodes can be renamed to a unique name + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "newname") + }) + assert.NoError(t, err) + + nodes, err = db.ListNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, nodes[0].Hostname, "test") + assert.Equal(t, nodes[0].GivenName, "newname") + + // Nodes can reuse name that is no longer used + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[1].ID, "test") + }) + assert.NoError(t, err) + + nodes, err = db.ListNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, nodes[0].Hostname, "test") + assert.Equal(t, nodes[0].GivenName, "newname") + assert.Equal(t, nodes[1].GivenName, "test") + + // Nodes cannot be renamed to used names + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "test") + }) + assert.ErrorContains(t, err, "name is not unique") +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 3f985d98..596748f2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode( node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { err := db.RenameNode( tx, - request.GetNodeId(), + types.NodeID(request.GetNodeId()), request.GetNewName(), ) if err != nil { @@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode( return nil, err } - givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName()) - if err != nil { - return nil, err - } - nodeKey := key.NewNode() newNode := types.Node{ MachineKey: mkey, NodeKey: nodeKey.Public(), Hostname: request.GetName(), - GivenName: givenName, User: *user, Expiry: &time.Time{},