diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index dc6669bd..4ae31e68 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -156,6 +156,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, + enrollmentKeyBody.Groups, enrollmentKeyBody.Unlimited, relayId, ) @@ -206,7 +207,7 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) { } } - newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId) + newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId, enrollmentKeyBody.Groups) if err != nil { slog.Error("failed to update enrollment key", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -307,6 +308,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { return } } + newHost.Tags = make(map[models.TagID]struct{}) + for _, tagI := range enrollmentKey.Groups { + newHost.Tags[tagI] = struct{}{} + } if err = logic.CreateHost(&newHost); err != nil { logger.Log( 0, @@ -337,6 +342,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { return } logic.UpdateHostFromClient(&newHost, currHost) + currHost.Tags = make(map[models.TagID]struct{}) + for _, tagI := range enrollmentKey.Groups { + currHost.Tags[tagI] = struct{}{} + } err = logic.UpsertHost(currHost) if err != nil { slog.Error("failed to update host", "id", currHost.ID, "error", err) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index d3c48a01..bf811a1a 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -37,7 +37,7 @@ var ( ) // CreateEnrollmentKey - creates a new enrollment key in db -func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { newKeyID, err := getUniqueEnrollmentID() if err != nil { return nil, err @@ -51,6 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string Tags: []string{}, Type: models.Undefined, Relay: relay, + Groups: groups, } if uses > 0 { k.UsesRemaining = uses @@ -89,7 +90,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string } // UpdateEnrollmentKey - updates an existing enrollment key's associated relay -func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) { +func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID) (*models.EnrollmentKey, error) { key, err := GetEnrollmentKey(keyId) if err != nil { return nil, err @@ -109,7 +110,7 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey } key.Relay = relayId - + key.Groups = groups if err = upsertEnrollmentKey(&key); err != nil { return nil, err } diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 677c4714..5e63df16 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() t.Run("Can_Not_Create_Key", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil) assert.Nil(t, newKey) assert.NotNil(t, err) assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) assert.Nil(t, err) assert.Equal(t, 1, newKey.UsesRemaining) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Time", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Networks) == 2) }) t.Run("Can_Create_Key_WithTags", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Tags) == 2) @@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) { func TestDelete_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) t.Run("Can_Delete_Key", func(t *testing.T) { assert.True(t, newKey.IsValid()) err := DeleteEnrollmentKey(newKey.Value) @@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) t.Run("Check_initial_uses", func(t *testing.T) { assert.True(t, newKey.IsValid()) assert.Equal(t, newKey.UsesRemaining, 1) @@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) { func TestUsability_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) - key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil) - key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil) + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) t.Run("Check if valid use key can be used", func(t *testing.T) { assert.Equal(t, key1.UsesRemaining, 1) ok := TryToUseEnrollmentKey(key1) @@ -145,7 +145,7 @@ func removeAllEnrollments() { func TestTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" @@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { func TestDeTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" diff --git a/models/enrollment_key.go b/models/enrollment_key.go index e775344d..5aa89c8a 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -52,6 +52,7 @@ type EnrollmentKey struct { Token string `json:"token,omitempty"` // B64 value of EnrollmentToken Type KeyType `json:"type"` Relay uuid.UUID `json:"relay"` + Groups []TagID `json:"groups"` } // APIEnrollmentKey - used to create enrollment keys via API @@ -63,6 +64,7 @@ type APIEnrollmentKey struct { Tags []string `json:"tags" validate:"required,dive,min=3,max=32"` Type KeyType `json:"type"` Relay string `json:"relay"` + Groups []TagID `json:"groups"` } // RegisterResponse - the response to a successful enrollment register