added try to use func and edited tests

This commit is contained in:
0xdcarns 2023-02-15 15:52:58 -05:00
parent db4ea9faa4
commit 0e5e34ef0c
3 changed files with 76 additions and 27 deletions

View file

@ -2,6 +2,7 @@ package logic
import (
"encoding/json"
"errors"
"fmt"
"time"
@ -12,15 +13,15 @@ import (
// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages
var EnrollmentKeyErrors = struct {
InvalidCreate string
NoKeyFound string
InvalidKey string
NoUsesRemaining string
InvalidCreate error
NoKeyFound error
InvalidKey error
NoUsesRemaining error
}{
InvalidCreate: "invalid enrollment key created",
NoKeyFound: "no enrollmentkey found",
InvalidKey: "invalid key provided",
NoUsesRemaining: "no uses remaining",
InvalidCreate: fmt.Errorf("invalid enrollment key created"),
NoKeyFound: fmt.Errorf("no enrollmentkey found"),
InvalidKey: fmt.Errorf("invalid key provided"),
NoUsesRemaining: fmt.Errorf("no uses remaining"),
}
// CreateEnrollmentKey - creates a new enrollment key in db
@ -50,7 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
k.Tags = tags
}
if ok := k.Validate(); !ok {
return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate)
return nil, EnrollmentKeyErrors.InvalidCreate
}
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
@ -81,7 +82,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
if key, ok := currentKeys[value]; ok {
return key, nil
}
return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound)
return nil, EnrollmentKeyErrors.NoKeyFound
}
// DeleteEnrollmentKey - delete's a given enrollment key by value
@ -93,14 +94,31 @@ func DeleteEnrollmentKey(value string) error {
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
}
// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining
func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
// TryToUseEnrollmentKey - checks first if key can be decremented
// returns true if it is decremented or isvalid
func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
key, err := decrementEnrollmentKey(k.Value)
if err != nil {
if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) {
return k.IsValid()
}
} else {
k.UsesRemaining = key.UsesRemaining
return true
}
return false
}
// == private ==
// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
k, err := GetEnrollmentKey(value)
if err != nil {
return nil, err
}
if k.UsesRemaining == 0 {
return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining)
return nil, EnrollmentKeyErrors.NoUsesRemaining
}
k.UsesRemaining = k.UsesRemaining - 1
if err = upsertEnrollmentKey(k); err != nil {
@ -110,11 +128,9 @@ func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
return k, nil
}
// == private ==
func upsertEnrollmentKey(k *models.EnrollmentKey) error {
if k == nil {
return fmt.Errorf(EnrollmentKeyErrors.InvalidKey)
return EnrollmentKeyErrors.InvalidKey
}
data, err := json.Marshal(k)
if err != nil {

View file

@ -15,7 +15,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate)
assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
@ -59,12 +59,12 @@ func TestDelete_EnrollmentKey(t *testing.T) {
oldKey, err := GetEnrollmentKey(newKey.Value)
assert.Nil(t, oldKey)
assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound)
assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
})
t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
err := DeleteEnrollmentKey("notakey")
assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound)
assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
})
removeAllEnrollments()
}
@ -72,32 +72,57 @@ func TestDelete_EnrollmentKey(t *testing.T) {
func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true)
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1)
})
t.Run("Check can decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 1)
k, err := DecrementEnrollmentKey(newKey.Value)
k, err := decrementEnrollmentKey(newKey.Value)
assert.Nil(t, err)
newKey = k
})
t.Run("Check can not decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 0)
_, err := DecrementEnrollmentKey(newKey.Value)
_, err := decrementEnrollmentKey(newKey.Value)
assert.NotNil(t, err)
assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining)
assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining)
})
removeAllEnrollments()
}
// func TestValidity_EnrollmentKey(t *testing.T) {
// database.InitializeDatabase()
// defer database.CloseDB()
func TestUsability_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
assert.True(t, ok)
assert.Equal(t, 0, key1.UsesRemaining)
})
// }
t.Run("Check if valid time key can be used", func(t *testing.T) {
assert.True(t, !key2.Expiration.IsZero())
ok := TryToUseEnrollmentKey(key2)
assert.True(t, ok)
})
t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
assert.True(t, key3.Unlimited)
ok := TryToUseEnrollmentKey(key3)
assert.True(t, ok)
})
t.Run("check invalid key can not be used", func(t *testing.T) {
ok := TryToUseEnrollmentKey(key1)
assert.False(t, ok)
})
}
func removeAllEnrollments() {
database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)

View file

@ -1,6 +1,7 @@
package logic
import (
"context"
"net"
"testing"
@ -13,6 +14,13 @@ import (
func TestCheckPorts(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
peerUpdate := make(chan *models.Node)
go ManageZombies(context.Background(), peerUpdate)
go func() {
for _ = range peerUpdate {
//do nothing
}
}()
h := models.Host{
ID: uuid.New(),