From db4ea9faa43aea49ec4565162ea8e2481d3f625d Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Wed, 15 Feb 2023 15:27:26 -0500 Subject: [PATCH] completed crud unit tests --- config/config_test.go | 8 ++-- logic/accesskeys_test.go | 2 - logic/{enrollment_key.go => enrollmentkey.go} | 40 +++++++++++++--- ...ment_key_test.go => enrollmentkey_test.go} | 46 ++++++++++++++++--- logic/host_test.go | 21 +-------- models/enrollment_key.go | 6 ++- 6 files changed, 82 insertions(+), 41 deletions(-) rename logic/{enrollment_key.go => enrollmentkey.go} (78%) rename logic/{enrollment_key_test.go => enrollmentkey_test.go} (56%) diff --git a/config/config_test.go b/config/config_test.go index 2a8205fa..04c2c144 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,7 @@ -//Environment file for getting variables -//Currently the only thing it does is set the master password -//Should probably have it take over functions from OS such as port and mongodb connection details -//Reads from the config/environments/dev.yaml file by default +// Environment file for getting variables +// Currently the only thing it does is set the master password +// Should probably have it take over functions from OS such as port and mongodb connection details +// Reads from the config/environments/dev.yaml file by default package config import ( diff --git a/logic/accesskeys_test.go b/logic/accesskeys_test.go index 030508a9..fe5443fc 100644 --- a/logic/accesskeys_test.go +++ b/logic/accesskeys_test.go @@ -5,7 +5,6 @@ import "testing" func Test_genKeyName(t *testing.T) { for i := 0; i < 100; i++ { kname := genKeyName() - t.Log(kname) if len(kname) != 20 { t.Fatalf("improper length of key name, expected 20 got :%d", len(kname)) } @@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) { func Test_genKey(t *testing.T) { for i := 0; i < 100; i++ { kname := GenKey() - t.Log(kname) if len(kname) != 16 { t.Fatalf("improper length of key name, expected 16 got :%d", len(kname)) } diff --git a/logic/enrollment_key.go b/logic/enrollmentkey.go similarity index 78% rename from logic/enrollment_key.go rename to logic/enrollmentkey.go index 02315b62..4dc021fa 100644 --- a/logic/enrollment_key.go +++ b/logic/enrollmentkey.go @@ -12,13 +12,15 @@ import ( // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages var EnrollmentKeyErrors = struct { - InvalidCreate string - NoKeyFound string - InvalidKey string + InvalidCreate string + NoKeyFound string + InvalidKey string + NoUsesRemaining string }{ - InvalidCreate: "invalid enrollment key created", - NoKeyFound: "no enrollmentkey found", - InvalidKey: "invalid key provided", + InvalidCreate: "invalid enrollment key created", + NoKeyFound: "no enrollmentkey found", + InvalidKey: "invalid key provided", + NoUsesRemaining: "no uses remaining", } // CreateEnrollmentKey - creates a new enrollment key in db @@ -84,9 +86,30 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { // DeleteEnrollmentKey - delete's a given enrollment key by value func DeleteEnrollmentKey(value string) error { + _, err := GetEnrollmentKey(value) + if err != nil { + return err + } return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) } +// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining +func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { + k, err := GetEnrollmentKey(value) + if err != nil { + return nil, err + } + if k.UsesRemaining == 0 { + return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining) + } + k.UsesRemaining = k.UsesRemaining - 1 + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + + return k, nil +} + // == private == func upsertEnrollmentKey(k *models.EnrollmentKey) error { @@ -106,7 +129,7 @@ func getUniqueEnrollmentID() (string, error) { return "", err } newID := ncutils.MakeRandomString(32) - for _, ok := currentKeys[newID]; !ok; { + for _, ok := currentKeys[newID]; ok; { newID = ncutils.MakeRandomString(32) } return newID, nil @@ -119,6 +142,9 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { return nil, err } } + if records == nil { + records = make(map[string]string) + } currentKeys := make(map[string]*models.EnrollmentKey) if len(records) > 0 { for k := range records { diff --git a/logic/enrollment_key_test.go b/logic/enrollmentkey_test.go similarity index 56% rename from logic/enrollment_key_test.go rename to logic/enrollmentkey_test.go index caf110f4..935a6c8f 100644 --- a/logic/enrollment_key_test.go +++ b/logic/enrollmentkey_test.go @@ -8,12 +8,13 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCreate_EnrollmentKey(t *testing.T) { +func TestCreateEnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() t.Run("Can_Not_Create_Key", func(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) assert.Nil(t, newKey) + assert.NotNil(t, err) assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { @@ -50,20 +51,53 @@ func TestCreate_EnrollmentKey(t *testing.T) { func TestDelete_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + t.Run("Can_Delete_Key", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + err := DeleteEnrollmentKey(newKey.Value) + assert.Nil(t, err) + oldKey, err := GetEnrollmentKey(newKey.Value) + assert.Nil(t, oldKey) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + }) + t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { + err := DeleteEnrollmentKey("notakey") + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + }) + removeAllEnrollments() } func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true) + t.Run("Check_initial_uses", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + assert.Equal(t, newKey.UsesRemaining, 1) + }) + t.Run("Check can decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 1) + k, err := DecrementEnrollmentKey(newKey.Value) + assert.Nil(t, err) + newKey = k + }) + t.Run("Check can not decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 0) + _, err := DecrementEnrollmentKey(newKey.Value) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining) + }) + removeAllEnrollments() } -func TestValidity_EnrollmentKey(t *testing.T) { - database.InitializeDatabase() - defer database.CloseDB() +// func TestValidity_EnrollmentKey(t *testing.T) { +// database.InitializeDatabase() +// defer database.CloseDB() -} +// } func removeAllEnrollments() { database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) diff --git a/logic/host_test.go b/logic/host_test.go index fdde345e..1b8320af 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -1,38 +1,19 @@ package logic import ( - "context" "net" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/matryer/is" ) -func TestMain(m *testing.M) { +func TestCheckPorts(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - CreateAdmin(&models.User{ - UserName: "admin", - Password: "password", - IsAdmin: true, - Networks: []string{}, - Groups: []string{}, - }) - peerUpdate := make(chan *models.Node) - go ManageZombies(context.Background(), peerUpdate) - go func() { - for update := range peerUpdate { - //do nothing - logger.Log(3, "received node update", update.Action) - } - }() -} -func TestCheckPorts(t *testing.T) { h := models.Host{ ID: uuid.New(), EndpointIP: net.ParseIP("192.168.1.1"), diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 0072aca0..394a25c0 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -1,6 +1,8 @@ package models -import "time" +import ( + "time" +) // EnrollmentKeyLength - the length of an enrollment key const EnrollmentKeyLength = 32 @@ -23,7 +25,7 @@ func (k *EnrollmentKey) IsValid() bool { if k.UsesRemaining > 0 { return true } - if !k.Expiration.IsZero() && time.Now().After(k.Expiration) { + if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) { return true }