From a5e7147b69e235a22bdfec5532ae18a33160f130 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Tue, 14 Feb 2023 17:21:51 -0500 Subject: [PATCH] initial commit, began unit tests --- database/database.go | 3 + logic/enrollment_key.go | 133 +++++++++++++++++++++++++++++++++++ logic/enrollment_key_test.go | 70 ++++++++++++++++++ models/enrollment_key.go | 40 +++++++++++ 4 files changed, 246 insertions(+) create mode 100644 logic/enrollment_key.go create mode 100644 logic/enrollment_key_test.go create mode 100644 models/enrollment_key.go diff --git a/database/database.go b/database/database.go index 283a829b..38d4fd40 100644 --- a/database/database.go +++ b/database/database.go @@ -57,6 +57,8 @@ const ( CACHE_TABLE_NAME = "cache" // HOSTS_TABLE_NAME - the table name for hosts HOSTS_TABLE_NAME = "hosts" + // ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys + ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys" // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -138,6 +140,7 @@ func createTables() { createTable(USER_GROUPS_TABLE_NAME) createTable(CACHE_TABLE_NAME) createTable(HOSTS_TABLE_NAME) + createTable(ENROLLMENT_KEYS_TABLE_NAME) } func createTable(tableName string) error { diff --git a/logic/enrollment_key.go b/logic/enrollment_key.go new file mode 100644 index 00000000..02315b62 --- /dev/null +++ b/logic/enrollment_key.go @@ -0,0 +1,133 @@ +package logic + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/ncutils" +) + +// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages +var EnrollmentKeyErrors = struct { + InvalidCreate string + NoKeyFound string + InvalidKey string +}{ + InvalidCreate: "invalid enrollment key created", + NoKeyFound: "no enrollmentkey found", + InvalidKey: "invalid key provided", +} + +// CreateEnrollmentKey - creates a new enrollment key in db +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) { + newKeyID, err := getUniqueEnrollmentID() + if err != nil { + return nil, err + } + k = &models.EnrollmentKey{ + Value: newKeyID, + Expiration: time.Time{}, + UsesRemaining: 0, + Unlimited: unlimited, + Networks: []string{}, + Tags: []string{}, + } + if uses > 0 { + k.UsesRemaining = uses + } + if !expiration.IsZero() { + k.Expiration = expiration + } + if len(networks) > 0 { + k.Networks = networks + } + if len(tags) > 0 { + k.Tags = tags + } + if ok := k.Validate(); !ok { + return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate) + } + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + return +} + +// GetAllEnrollmentKeys - fetches all enrollment keys from DB +func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + var currentKeysList = make([]*models.EnrollmentKey, len(currentKeys)) + for k := range currentKeys { + currentKeysList = append(currentKeysList, currentKeys[k]) + } + return currentKeysList, nil +} + +// GetEnrollmentKey - fetches a single enrollment key +// returns nil and error if not found +func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + if key, ok := currentKeys[value]; ok { + return key, nil + } + return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound) +} + +// DeleteEnrollmentKey - delete's a given enrollment key by value +func DeleteEnrollmentKey(value string) error { + return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) +} + +// == private == + +func upsertEnrollmentKey(k *models.EnrollmentKey) error { + if k == nil { + return fmt.Errorf(EnrollmentKeyErrors.InvalidKey) + } + data, err := json.Marshal(k) + if err != nil { + return err + } + return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) +} + +func getUniqueEnrollmentID() (string, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return "", err + } + newID := ncutils.MakeRandomString(32) + for _, ok := currentKeys[newID]; !ok; { + newID = ncutils.MakeRandomString(32) + } + return newID, nil +} + +func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { + records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) + if err != nil { + if !database.IsEmptyRecord(err) { + return nil, err + } + } + currentKeys := make(map[string]*models.EnrollmentKey) + if len(records) > 0 { + for k := range records { + var currentKey models.EnrollmentKey + if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil { + continue + } + currentKeys[k] = ¤tKey + } + } + return currentKeys, nil +} diff --git a/logic/enrollment_key_test.go b/logic/enrollment_key_test.go new file mode 100644 index 00000000..caf110f4 --- /dev/null +++ b/logic/enrollment_key_test.go @@ -0,0 +1,70 @@ +package logic + +import ( + "testing" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/stretchr/testify/assert" +) + +func TestCreate_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + t.Run("Can_Not_Create_Key", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) + assert.Nil(t, newKey) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) + }) + t.Run("Can_Create_Key_Uses", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + assert.Nil(t, err) + assert.Equal(t, 1, newKey.UsesRemaining) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Time", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Networks) == 2) + }) + t.Run("Can_Create_Key_WithTags", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Tags) == 2) + }) + removeAllEnrollments() +} + +func TestDelete_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func TestDecrement_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func TestValidity_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func removeAllEnrollments() { + database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) +} diff --git a/models/enrollment_key.go b/models/enrollment_key.go new file mode 100644 index 00000000..0072aca0 --- /dev/null +++ b/models/enrollment_key.go @@ -0,0 +1,40 @@ +package models + +import "time" + +// EnrollmentKeyLength - the length of an enrollment key +const EnrollmentKeyLength = 32 + +// EnrollmentKey - the +type EnrollmentKey struct { + Expiration time.Time `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Value string `json:"value"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` +} + +// EnrollmentKey.IsValid - checks if the key is still valid to use +func (k *EnrollmentKey) IsValid() bool { + if k == nil { + return false + } + if k.UsesRemaining > 0 { + return true + } + if !k.Expiration.IsZero() && time.Now().After(k.Expiration) { + return true + } + + return k.Unlimited +} + +// EnrollmentKey.Validate - validate's an EnrollmentKey +// should be used during creation +func (k *EnrollmentKey) Validate() bool { + return k.Networks != nil && + k.Tags != nil && + len(k.Value) == EnrollmentKeyLength && + k.IsValid() +}