From 56d5c85da7240c9561316595ab0dd35e99b8e193 Mon Sep 17 00:00:00 2001 From: abhishek9686 Date: Wed, 30 Oct 2024 15:58:55 +0400 Subject: [PATCH] block default key deletion,delete default key on network deletion --- controllers/enrollmentkeys.go | 3 ++- controllers/tags.go | 1 + logic/enrollmentkey.go | 29 ++++++++++++++++++++++++++--- logic/enrollmentkey_test.go | 30 +++++++++++++++--------------- logic/networks.go | 12 ++++++++++++ migrate/migrate.go | 1 + models/enrollment_key.go | 1 + 7 files changed, 58 insertions(+), 19 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 9d7fbe43..1ab9498e 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -72,7 +72,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { params := mux.Vars(r) keyID := params["keyID"] - err := logic.DeleteEnrollmentKey(keyID) + err := logic.DeleteEnrollmentKey(keyID, false) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -159,6 +159,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { enrollmentKeyBody.Groups, enrollmentKeyBody.Unlimited, relayId, + false, ) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) diff --git a/controllers/tags.go b/controllers/tags.go index 9b19aba3..633dab96 100644 --- a/controllers/tags.go +++ b/controllers/tags.go @@ -224,6 +224,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) { go func() { logic.RemoveDeviceTagFromAclPolicies(tag.ID, tag.Network) + logic.RemoveTagFromEnrollmentKeys(tag.ID) mq.PublishPeerUpdate(false) }() logic.ReturnSuccessResponse(w, r, "deleted tag "+tagID) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index bf811a1a..b479c302 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -37,7 +37,7 @@ var ( ) // CreateEnrollmentKey - creates a new enrollment key in db -func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey bool) (*models.EnrollmentKey, error) { newKeyID, err := getUniqueEnrollmentID() if err != nil { return nil, err @@ -152,11 +152,14 @@ func deleteEnrollmentkeyFromCache(key string) { } // DeleteEnrollmentKey - delete's a given enrollment key by value -func DeleteEnrollmentKey(value string) error { - _, err := GetEnrollmentKey(value) +func DeleteEnrollmentKey(value string, force bool) error { + key, err := GetEnrollmentKey(value) if err != nil { return err } + if key.Default && !force { + return errors.New("cannot delete default network key") + } err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) if err == nil { if servercfg.CacheEnabled() { @@ -311,3 +314,23 @@ func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) { } return currentKeys, nil } + +func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) { + keys, _ := GetAllEnrollmentKeys() + for _, key := range keys { + newTags := []models.TagID{} + update := false + for _, tagID := range key.Groups { + if tagID == deletedTagID { + update = true + continue + } + newTags = append(newTags, tagID) + } + if update { + key.Groups = newTags + upsertEnrollmentKey(&key) + } + + } +} diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 5e63df16..92b4c5e2 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() t.Run("Can_Not_Create_Key", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false) assert.Nil(t, newKey) assert.NotNil(t, err) assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false) assert.Nil(t, err) assert.Equal(t, 1, newKey.UsesRemaining) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Time", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Networks) == 2) }) t.Run("Can_Create_Key_WithTags", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Tags) == 2) @@ -62,10 +62,10 @@ func TestCreateEnrollmentKey(t *testing.T) { func TestDelete_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false) t.Run("Can_Delete_Key", func(t *testing.T) { assert.True(t, newKey.IsValid()) - err := DeleteEnrollmentKey(newKey.Value) + err := DeleteEnrollmentKey(newKey.Value, false) assert.Nil(t, err) oldKey, err := GetEnrollmentKey(newKey.Value) assert.Equal(t, oldKey, models.EnrollmentKey{}) @@ -73,7 +73,7 @@ func TestDelete_EnrollmentKey(t *testing.T) { assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { - err := DeleteEnrollmentKey("notakey") + err := DeleteEnrollmentKey("notakey", false) assert.NotNil(t, err) assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) @@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false) t.Run("Check_initial_uses", func(t *testing.T) { assert.True(t, newKey.IsValid()) assert.Equal(t, newKey.UsesRemaining, 1) @@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) { func TestUsability_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) - key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil) - key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false) t.Run("Check if valid use key can be used", func(t *testing.T) { assert.Equal(t, key1.UsesRemaining, 1) ok := TryToUseEnrollmentKey(key1) @@ -145,7 +145,7 @@ func removeAllEnrollments() { func TestTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false) const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" @@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { func TestDeTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false) const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" diff --git a/logic/networks.go b/logic/networks.go index 1a50fa79..1617889d 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -177,6 +177,17 @@ func DeleteNetwork(network string) error { if err != nil { logger.Log(1, "failed to remove the node acls during network delete for network,", network) } + // Delete default network enrollment key + keys, _ := GetAllEnrollmentKeys() + for _, key := range keys { + if key.Tags[0] == network { + if key.Default { + DeleteEnrollmentKey(key.Value, true) + break + } + + } + } nodeCount, err := GetNetworkNonServerNodeCount(network) if nodeCount == 0 || database.IsEmptyRecord(err) { // delete server nodes first then db records @@ -243,6 +254,7 @@ func CreateNetwork(network models.Network) (models.Network, error) { []models.TagID{}, true, uuid.Nil, + true, ) return network, nil diff --git a/migrate/migrate.go b/migrate/migrate.go index 36e4534c..1c697873 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -149,6 +149,7 @@ func updateEnrollmentKeys() { []models.TagID{}, true, uuid.Nil, + true, ) } diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 5aa89c8a..f133d755 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -53,6 +53,7 @@ type EnrollmentKey struct { Type KeyType `json:"type"` Relay uuid.UUID `json:"relay"` Groups []TagID `json:"groups"` + Default bool `json:"default"` } // APIEnrollmentKey - used to create enrollment keys via API