diff --git a/machine.go b/machine.go index 1bed2955..dda49020 100644 --- a/machine.go +++ b/machine.go @@ -374,7 +374,13 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { // SetTags takes a Machine struct pointer and update the forced tags. func (h *Headscale) SetTags(machine *Machine, tags []string) error { - machine.ForcedTags = tags + newTags := []string{} + for _, tag := range tags { + if !contains(newTags, tag) { + newTags = append(newTags, tag) + } + } + machine.ForcedTags = newTags if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) { return err } diff --git a/machine_test.go b/machine_test.go index a06d0db2..35c3eed9 100644 --- a/machine_test.go +++ b/machine_test.go @@ -280,6 +280,49 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { } } +func (s *Suite) TestSetTags(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + app.db.Save(machine) + + // assign simple tags + sTags := []string{"tag:test", "tag:foo"} + err = app.SetTags(machine, sTags) + c.Assert(err, check.IsNil) + machine, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) + + // assign duplicat tags, expect no errors but no doubles in DB + eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} + err = app.SetTags(machine, eTags) + c.Assert(err, check.IsNil) + machine, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert( + machine.ForcedTags, + check.DeepEquals, + StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + ) +} + func Test_getTags(t *testing.T) { type args struct { aclPolicy *ACLPolicy