diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 4dc021fa..e0fe3b58 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -2,6 +2,7 @@ package logic import ( "encoding/json" + "errors" "fmt" "time" @@ -12,15 +13,15 @@ import ( // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages var EnrollmentKeyErrors = struct { - InvalidCreate string - NoKeyFound string - InvalidKey string - NoUsesRemaining string + InvalidCreate error + NoKeyFound error + InvalidKey error + NoUsesRemaining error }{ - InvalidCreate: "invalid enrollment key created", - NoKeyFound: "no enrollmentkey found", - InvalidKey: "invalid key provided", - NoUsesRemaining: "no uses remaining", + InvalidCreate: fmt.Errorf("invalid enrollment key created"), + NoKeyFound: fmt.Errorf("no enrollmentkey found"), + InvalidKey: fmt.Errorf("invalid key provided"), + NoUsesRemaining: fmt.Errorf("no uses remaining"), } // CreateEnrollmentKey - creates a new enrollment key in db @@ -50,7 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string k.Tags = tags } if ok := k.Validate(); !ok { - return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate) + return nil, EnrollmentKeyErrors.InvalidCreate } if err = upsertEnrollmentKey(k); err != nil { return nil, err @@ -81,7 +82,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { if key, ok := currentKeys[value]; ok { return key, nil } - return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound) + return nil, EnrollmentKeyErrors.NoKeyFound } // DeleteEnrollmentKey - delete's a given enrollment key by value @@ -93,14 +94,31 @@ func DeleteEnrollmentKey(value string) error { return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) } -// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining -func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { +// TryToUseEnrollmentKey - checks first if key can be decremented +// returns true if it is decremented or isvalid +func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { + key, err := decrementEnrollmentKey(k.Value) + if err != nil { + if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) { + return k.IsValid() + } + } else { + k.UsesRemaining = key.UsesRemaining + return true + } + return false +} + +// == private == + +// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining +func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { k, err := GetEnrollmentKey(value) if err != nil { return nil, err } if k.UsesRemaining == 0 { - return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining) + return nil, EnrollmentKeyErrors.NoUsesRemaining } k.UsesRemaining = k.UsesRemaining - 1 if err = upsertEnrollmentKey(k); err != nil { @@ -110,11 +128,9 @@ func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { return k, nil } -// == private == - func upsertEnrollmentKey(k *models.EnrollmentKey) error { if k == nil { - return fmt.Errorf(EnrollmentKeyErrors.InvalidKey) + return EnrollmentKeyErrors.InvalidKey } data, err := json.Marshal(k) if err != nil { diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 935a6c8f..ddee0d43 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -15,7 +15,7 @@ func TestCreateEnrollmentKey(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) assert.Nil(t, newKey) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) + assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) @@ -59,12 +59,12 @@ func TestDelete_EnrollmentKey(t *testing.T) { oldKey, err := GetEnrollmentKey(newKey.Value) assert.Nil(t, oldKey) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) }) t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { err := DeleteEnrollmentKey("notakey") assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) }) removeAllEnrollments() } @@ -72,32 +72,57 @@ func TestDelete_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true) + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) t.Run("Check_initial_uses", func(t *testing.T) { assert.True(t, newKey.IsValid()) assert.Equal(t, newKey.UsesRemaining, 1) }) t.Run("Check can decrement", func(t *testing.T) { assert.Equal(t, newKey.UsesRemaining, 1) - k, err := DecrementEnrollmentKey(newKey.Value) + k, err := decrementEnrollmentKey(newKey.Value) assert.Nil(t, err) newKey = k }) t.Run("Check can not decrement", func(t *testing.T) { assert.Equal(t, newKey.UsesRemaining, 0) - _, err := DecrementEnrollmentKey(newKey.Value) + _, err := decrementEnrollmentKey(newKey.Value) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining) + assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining) }) removeAllEnrollments() } -// func TestValidity_EnrollmentKey(t *testing.T) { -// database.InitializeDatabase() -// defer database.CloseDB() +func TestUsability_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + t.Run("Check if valid use key can be used", func(t *testing.T) { + assert.Equal(t, key1.UsesRemaining, 1) + ok := TryToUseEnrollmentKey(key1) + assert.True(t, ok) + assert.Equal(t, 0, key1.UsesRemaining) + }) -// } + t.Run("Check if valid time key can be used", func(t *testing.T) { + assert.True(t, !key2.Expiration.IsZero()) + ok := TryToUseEnrollmentKey(key2) + assert.True(t, ok) + }) + + t.Run("Check if valid unlimited key can be used", func(t *testing.T) { + assert.True(t, key3.Unlimited) + ok := TryToUseEnrollmentKey(key3) + assert.True(t, ok) + }) + + t.Run("check invalid key can not be used", func(t *testing.T) { + ok := TryToUseEnrollmentKey(key1) + assert.False(t, ok) + }) +} func removeAllEnrollments() { database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) diff --git a/logic/host_test.go b/logic/host_test.go index 1b8320af..94ce4b9d 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -1,6 +1,7 @@ package logic import ( + "context" "net" "testing" @@ -13,6 +14,13 @@ import ( func TestCheckPorts(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() + peerUpdate := make(chan *models.Node) + go ManageZombies(context.Background(), peerUpdate) + go func() { + for _ = range peerUpdate { + //do nothing + } + }() h := models.Host{ ID: uuid.New(),