From a5e7147b69e235a22bdfec5532ae18a33160f130 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Tue, 14 Feb 2023 17:21:51 -0500 Subject: [PATCH 01/23] initial commit, began unit tests --- database/database.go | 3 + logic/enrollment_key.go | 133 +++++++++++++++++++++++++++++++++++ logic/enrollment_key_test.go | 70 ++++++++++++++++++ models/enrollment_key.go | 40 +++++++++++ 4 files changed, 246 insertions(+) create mode 100644 logic/enrollment_key.go create mode 100644 logic/enrollment_key_test.go create mode 100644 models/enrollment_key.go diff --git a/database/database.go b/database/database.go index 283a829b..38d4fd40 100644 --- a/database/database.go +++ b/database/database.go @@ -57,6 +57,8 @@ const ( CACHE_TABLE_NAME = "cache" // HOSTS_TABLE_NAME - the table name for hosts HOSTS_TABLE_NAME = "hosts" + // ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys + ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys" // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -138,6 +140,7 @@ func createTables() { createTable(USER_GROUPS_TABLE_NAME) createTable(CACHE_TABLE_NAME) createTable(HOSTS_TABLE_NAME) + createTable(ENROLLMENT_KEYS_TABLE_NAME) } func createTable(tableName string) error { diff --git a/logic/enrollment_key.go b/logic/enrollment_key.go new file mode 100644 index 00000000..02315b62 --- /dev/null +++ b/logic/enrollment_key.go @@ -0,0 +1,133 @@ +package logic + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/ncutils" +) + +// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages +var EnrollmentKeyErrors = struct { + InvalidCreate string + NoKeyFound string + InvalidKey string +}{ + InvalidCreate: "invalid enrollment key created", + NoKeyFound: "no enrollmentkey found", + InvalidKey: "invalid key provided", +} + +// CreateEnrollmentKey - creates a new enrollment key in db +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) { + newKeyID, err := getUniqueEnrollmentID() + if err != nil { + return nil, err + } + k = &models.EnrollmentKey{ + Value: newKeyID, + Expiration: time.Time{}, + UsesRemaining: 0, + Unlimited: unlimited, + Networks: []string{}, + Tags: []string{}, + } + if uses > 0 { + k.UsesRemaining = uses + } + if !expiration.IsZero() { + k.Expiration = expiration + } + if len(networks) > 0 { + k.Networks = networks + } + if len(tags) > 0 { + k.Tags = tags + } + if ok := k.Validate(); !ok { + return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate) + } + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + return +} + +// GetAllEnrollmentKeys - fetches all enrollment keys from DB +func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + var currentKeysList = make([]*models.EnrollmentKey, len(currentKeys)) + for k := range currentKeys { + currentKeysList = append(currentKeysList, currentKeys[k]) + } + return currentKeysList, nil +} + +// GetEnrollmentKey - fetches a single enrollment key +// returns nil and error if not found +func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + if key, ok := currentKeys[value]; ok { + return key, nil + } + return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound) +} + +// DeleteEnrollmentKey - delete's a given enrollment key by value +func DeleteEnrollmentKey(value string) error { + return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) +} + +// == private == + +func upsertEnrollmentKey(k *models.EnrollmentKey) error { + if k == nil { + return fmt.Errorf(EnrollmentKeyErrors.InvalidKey) + } + data, err := json.Marshal(k) + if err != nil { + return err + } + return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) +} + +func getUniqueEnrollmentID() (string, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return "", err + } + newID := ncutils.MakeRandomString(32) + for _, ok := currentKeys[newID]; !ok; { + newID = ncutils.MakeRandomString(32) + } + return newID, nil +} + +func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { + records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) + if err != nil { + if !database.IsEmptyRecord(err) { + return nil, err + } + } + currentKeys := make(map[string]*models.EnrollmentKey) + if len(records) > 0 { + for k := range records { + var currentKey models.EnrollmentKey + if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil { + continue + } + currentKeys[k] = ¤tKey + } + } + return currentKeys, nil +} diff --git a/logic/enrollment_key_test.go b/logic/enrollment_key_test.go new file mode 100644 index 00000000..caf110f4 --- /dev/null +++ b/logic/enrollment_key_test.go @@ -0,0 +1,70 @@ +package logic + +import ( + "testing" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/stretchr/testify/assert" +) + +func TestCreate_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + t.Run("Can_Not_Create_Key", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) + assert.Nil(t, newKey) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) + }) + t.Run("Can_Create_Key_Uses", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + assert.Nil(t, err) + assert.Equal(t, 1, newKey.UsesRemaining) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Time", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Networks) == 2) + }) + t.Run("Can_Create_Key_WithTags", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Tags) == 2) + }) + removeAllEnrollments() +} + +func TestDelete_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func TestDecrement_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func TestValidity_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + +} + +func removeAllEnrollments() { + database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) +} diff --git a/models/enrollment_key.go b/models/enrollment_key.go new file mode 100644 index 00000000..0072aca0 --- /dev/null +++ b/models/enrollment_key.go @@ -0,0 +1,40 @@ +package models + +import "time" + +// EnrollmentKeyLength - the length of an enrollment key +const EnrollmentKeyLength = 32 + +// EnrollmentKey - the +type EnrollmentKey struct { + Expiration time.Time `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Value string `json:"value"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` +} + +// EnrollmentKey.IsValid - checks if the key is still valid to use +func (k *EnrollmentKey) IsValid() bool { + if k == nil { + return false + } + if k.UsesRemaining > 0 { + return true + } + if !k.Expiration.IsZero() && time.Now().After(k.Expiration) { + return true + } + + return k.Unlimited +} + +// EnrollmentKey.Validate - validate's an EnrollmentKey +// should be used during creation +func (k *EnrollmentKey) Validate() bool { + return k.Networks != nil && + k.Tags != nil && + len(k.Value) == EnrollmentKeyLength && + k.IsValid() +} From db4ea9faa43aea49ec4565162ea8e2481d3f625d Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Wed, 15 Feb 2023 15:27:26 -0500 Subject: [PATCH 02/23] completed crud unit tests --- config/config_test.go | 8 ++-- logic/accesskeys_test.go | 2 - logic/{enrollment_key.go => enrollmentkey.go} | 40 +++++++++++++--- ...ment_key_test.go => enrollmentkey_test.go} | 46 ++++++++++++++++--- logic/host_test.go | 21 +-------- models/enrollment_key.go | 6 ++- 6 files changed, 82 insertions(+), 41 deletions(-) rename logic/{enrollment_key.go => enrollmentkey.go} (78%) rename logic/{enrollment_key_test.go => enrollmentkey_test.go} (56%) diff --git a/config/config_test.go b/config/config_test.go index 2a8205fa..04c2c144 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,7 @@ -//Environment file for getting variables -//Currently the only thing it does is set the master password -//Should probably have it take over functions from OS such as port and mongodb connection details -//Reads from the config/environments/dev.yaml file by default +// Environment file for getting variables +// Currently the only thing it does is set the master password +// Should probably have it take over functions from OS such as port and mongodb connection details +// Reads from the config/environments/dev.yaml file by default package config import ( diff --git a/logic/accesskeys_test.go b/logic/accesskeys_test.go index 030508a9..fe5443fc 100644 --- a/logic/accesskeys_test.go +++ b/logic/accesskeys_test.go @@ -5,7 +5,6 @@ import "testing" func Test_genKeyName(t *testing.T) { for i := 0; i < 100; i++ { kname := genKeyName() - t.Log(kname) if len(kname) != 20 { t.Fatalf("improper length of key name, expected 20 got :%d", len(kname)) } @@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) { func Test_genKey(t *testing.T) { for i := 0; i < 100; i++ { kname := GenKey() - t.Log(kname) if len(kname) != 16 { t.Fatalf("improper length of key name, expected 16 got :%d", len(kname)) } diff --git a/logic/enrollment_key.go b/logic/enrollmentkey.go similarity index 78% rename from logic/enrollment_key.go rename to logic/enrollmentkey.go index 02315b62..4dc021fa 100644 --- a/logic/enrollment_key.go +++ b/logic/enrollmentkey.go @@ -12,13 +12,15 @@ import ( // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages var EnrollmentKeyErrors = struct { - InvalidCreate string - NoKeyFound string - InvalidKey string + InvalidCreate string + NoKeyFound string + InvalidKey string + NoUsesRemaining string }{ - InvalidCreate: "invalid enrollment key created", - NoKeyFound: "no enrollmentkey found", - InvalidKey: "invalid key provided", + InvalidCreate: "invalid enrollment key created", + NoKeyFound: "no enrollmentkey found", + InvalidKey: "invalid key provided", + NoUsesRemaining: "no uses remaining", } // CreateEnrollmentKey - creates a new enrollment key in db @@ -84,9 +86,30 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { // DeleteEnrollmentKey - delete's a given enrollment key by value func DeleteEnrollmentKey(value string) error { + _, err := GetEnrollmentKey(value) + if err != nil { + return err + } return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) } +// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining +func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { + k, err := GetEnrollmentKey(value) + if err != nil { + return nil, err + } + if k.UsesRemaining == 0 { + return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining) + } + k.UsesRemaining = k.UsesRemaining - 1 + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + + return k, nil +} + // == private == func upsertEnrollmentKey(k *models.EnrollmentKey) error { @@ -106,7 +129,7 @@ func getUniqueEnrollmentID() (string, error) { return "", err } newID := ncutils.MakeRandomString(32) - for _, ok := currentKeys[newID]; !ok; { + for _, ok := currentKeys[newID]; ok; { newID = ncutils.MakeRandomString(32) } return newID, nil @@ -119,6 +142,9 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { return nil, err } } + if records == nil { + records = make(map[string]string) + } currentKeys := make(map[string]*models.EnrollmentKey) if len(records) > 0 { for k := range records { diff --git a/logic/enrollment_key_test.go b/logic/enrollmentkey_test.go similarity index 56% rename from logic/enrollment_key_test.go rename to logic/enrollmentkey_test.go index caf110f4..935a6c8f 100644 --- a/logic/enrollment_key_test.go +++ b/logic/enrollmentkey_test.go @@ -8,12 +8,13 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCreate_EnrollmentKey(t *testing.T) { +func TestCreateEnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() t.Run("Can_Not_Create_Key", func(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) assert.Nil(t, newKey) + assert.NotNil(t, err) assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { @@ -50,20 +51,53 @@ func TestCreate_EnrollmentKey(t *testing.T) { func TestDelete_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + t.Run("Can_Delete_Key", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + err := DeleteEnrollmentKey(newKey.Value) + assert.Nil(t, err) + oldKey, err := GetEnrollmentKey(newKey.Value) + assert.Nil(t, oldKey) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + }) + t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { + err := DeleteEnrollmentKey("notakey") + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + }) + removeAllEnrollments() } func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true) + t.Run("Check_initial_uses", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + assert.Equal(t, newKey.UsesRemaining, 1) + }) + t.Run("Check can decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 1) + k, err := DecrementEnrollmentKey(newKey.Value) + assert.Nil(t, err) + newKey = k + }) + t.Run("Check can not decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 0) + _, err := DecrementEnrollmentKey(newKey.Value) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining) + }) + removeAllEnrollments() } -func TestValidity_EnrollmentKey(t *testing.T) { - database.InitializeDatabase() - defer database.CloseDB() +// func TestValidity_EnrollmentKey(t *testing.T) { +// database.InitializeDatabase() +// defer database.CloseDB() -} +// } func removeAllEnrollments() { database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) diff --git a/logic/host_test.go b/logic/host_test.go index fdde345e..1b8320af 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -1,38 +1,19 @@ package logic import ( - "context" "net" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/matryer/is" ) -func TestMain(m *testing.M) { +func TestCheckPorts(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - CreateAdmin(&models.User{ - UserName: "admin", - Password: "password", - IsAdmin: true, - Networks: []string{}, - Groups: []string{}, - }) - peerUpdate := make(chan *models.Node) - go ManageZombies(context.Background(), peerUpdate) - go func() { - for update := range peerUpdate { - //do nothing - logger.Log(3, "received node update", update.Action) - } - }() -} -func TestCheckPorts(t *testing.T) { h := models.Host{ ID: uuid.New(), EndpointIP: net.ParseIP("192.168.1.1"), diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 0072aca0..394a25c0 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -1,6 +1,8 @@ package models -import "time" +import ( + "time" +) // EnrollmentKeyLength - the length of an enrollment key const EnrollmentKeyLength = 32 @@ -23,7 +25,7 @@ func (k *EnrollmentKey) IsValid() bool { if k.UsesRemaining > 0 { return true } - if !k.Expiration.IsZero() && time.Now().After(k.Expiration) { + if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) { return true } From 0e5e34ef0c2ed326bb1a8bb8d0176988c8bdc862 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Wed, 15 Feb 2023 15:52:58 -0500 Subject: [PATCH 03/23] added try to use func and edited tests --- logic/enrollmentkey.go | 48 ++++++++++++++++++++++++------------- logic/enrollmentkey_test.go | 47 +++++++++++++++++++++++++++--------- logic/host_test.go | 8 +++++++ 3 files changed, 76 insertions(+), 27 deletions(-) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 4dc021fa..e0fe3b58 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -2,6 +2,7 @@ package logic import ( "encoding/json" + "errors" "fmt" "time" @@ -12,15 +13,15 @@ import ( // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages var EnrollmentKeyErrors = struct { - InvalidCreate string - NoKeyFound string - InvalidKey string - NoUsesRemaining string + InvalidCreate error + NoKeyFound error + InvalidKey error + NoUsesRemaining error }{ - InvalidCreate: "invalid enrollment key created", - NoKeyFound: "no enrollmentkey found", - InvalidKey: "invalid key provided", - NoUsesRemaining: "no uses remaining", + InvalidCreate: fmt.Errorf("invalid enrollment key created"), + NoKeyFound: fmt.Errorf("no enrollmentkey found"), + InvalidKey: fmt.Errorf("invalid key provided"), + NoUsesRemaining: fmt.Errorf("no uses remaining"), } // CreateEnrollmentKey - creates a new enrollment key in db @@ -50,7 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string k.Tags = tags } if ok := k.Validate(); !ok { - return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate) + return nil, EnrollmentKeyErrors.InvalidCreate } if err = upsertEnrollmentKey(k); err != nil { return nil, err @@ -81,7 +82,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { if key, ok := currentKeys[value]; ok { return key, nil } - return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound) + return nil, EnrollmentKeyErrors.NoKeyFound } // DeleteEnrollmentKey - delete's a given enrollment key by value @@ -93,14 +94,31 @@ func DeleteEnrollmentKey(value string) error { return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) } -// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining -func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { +// TryToUseEnrollmentKey - checks first if key can be decremented +// returns true if it is decremented or isvalid +func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { + key, err := decrementEnrollmentKey(k.Value) + if err != nil { + if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) { + return k.IsValid() + } + } else { + k.UsesRemaining = key.UsesRemaining + return true + } + return false +} + +// == private == + +// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining +func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { k, err := GetEnrollmentKey(value) if err != nil { return nil, err } if k.UsesRemaining == 0 { - return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining) + return nil, EnrollmentKeyErrors.NoUsesRemaining } k.UsesRemaining = k.UsesRemaining - 1 if err = upsertEnrollmentKey(k); err != nil { @@ -110,11 +128,9 @@ func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { return k, nil } -// == private == - func upsertEnrollmentKey(k *models.EnrollmentKey) error { if k == nil { - return fmt.Errorf(EnrollmentKeyErrors.InvalidKey) + return EnrollmentKeyErrors.InvalidKey } data, err := json.Marshal(k) if err != nil { diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 935a6c8f..ddee0d43 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -15,7 +15,7 @@ func TestCreateEnrollmentKey(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) assert.Nil(t, newKey) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate) + assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) @@ -59,12 +59,12 @@ func TestDelete_EnrollmentKey(t *testing.T) { oldKey, err := GetEnrollmentKey(newKey.Value) assert.Nil(t, oldKey) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) }) t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { err := DeleteEnrollmentKey("notakey") assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) }) removeAllEnrollments() } @@ -72,32 +72,57 @@ func TestDelete_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true) + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) t.Run("Check_initial_uses", func(t *testing.T) { assert.True(t, newKey.IsValid()) assert.Equal(t, newKey.UsesRemaining, 1) }) t.Run("Check can decrement", func(t *testing.T) { assert.Equal(t, newKey.UsesRemaining, 1) - k, err := DecrementEnrollmentKey(newKey.Value) + k, err := decrementEnrollmentKey(newKey.Value) assert.Nil(t, err) newKey = k }) t.Run("Check can not decrement", func(t *testing.T) { assert.Equal(t, newKey.UsesRemaining, 0) - _, err := DecrementEnrollmentKey(newKey.Value) + _, err := decrementEnrollmentKey(newKey.Value) assert.NotNil(t, err) - assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining) + assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining) }) removeAllEnrollments() } -// func TestValidity_EnrollmentKey(t *testing.T) { -// database.InitializeDatabase() -// defer database.CloseDB() +func TestUsability_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + t.Run("Check if valid use key can be used", func(t *testing.T) { + assert.Equal(t, key1.UsesRemaining, 1) + ok := TryToUseEnrollmentKey(key1) + assert.True(t, ok) + assert.Equal(t, 0, key1.UsesRemaining) + }) -// } + t.Run("Check if valid time key can be used", func(t *testing.T) { + assert.True(t, !key2.Expiration.IsZero()) + ok := TryToUseEnrollmentKey(key2) + assert.True(t, ok) + }) + + t.Run("Check if valid unlimited key can be used", func(t *testing.T) { + assert.True(t, key3.Unlimited) + ok := TryToUseEnrollmentKey(key3) + assert.True(t, ok) + }) + + t.Run("check invalid key can not be used", func(t *testing.T) { + ok := TryToUseEnrollmentKey(key1) + assert.False(t, ok) + }) +} func removeAllEnrollments() { database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) diff --git a/logic/host_test.go b/logic/host_test.go index 1b8320af..94ce4b9d 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -1,6 +1,7 @@ package logic import ( + "context" "net" "testing" @@ -13,6 +14,13 @@ import ( func TestCheckPorts(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() + peerUpdate := make(chan *models.Node) + go ManageZombies(context.Background(), peerUpdate) + go func() { + for _ = range peerUpdate { + //do nothing + } + }() h := models.Host{ ID: uuid.New(), From 71ce2caabd7aaf6af0d948094b6c4dbcf07f05ab Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Wed, 15 Feb 2023 16:32:16 -0500 Subject: [PATCH 04/23] added tokenization + detokenization --- controllers/controller.go | 1 + controllers/enrollmentkeys.go | 96 +++++++++++++++++++++++++++++++++++ logic/enrollmentkey.go | 62 +++++++++++++++++++--- models/enrollment_key.go | 8 +++ 4 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 controllers/enrollmentkeys.go diff --git a/controllers/controller.go b/controllers/controller.go index 571671bd..e4d0ef50 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -29,6 +29,7 @@ var HttpHandlers = []interface{}{ ipHandlers, loggerHandlers, hostHandlers, + enrollmentKeyHandlers, } // HandleRESTRequests - handles the rest requests diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go new file mode 100644 index 00000000..44773aeb --- /dev/null +++ b/controllers/enrollmentkeys.go @@ -0,0 +1,96 @@ +package controller + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/servercfg" +) + +func enrollmentKeyHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/host/register", logic.SecurityCheck(true, http.HandlerFunc(handleHostRegister))).Methods(http.MethodPost) +} + +// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys +// +// Lists all hosts. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: getEnrollmentKeysSlice +func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { + currentKeys, err := logic.GetAllEnrollmentKeys() + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + for i := range currentKeys { + if err = logic.Tokenize(currentKeys[i], servercfg.GetServer()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } + // return JSON/API formatted hosts + logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(currentKeys) +} + +// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey +// +// Deletes a Netclient host from Netmaker server. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: deleteEnrollmentKeyResponse +func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { + var params = mux.Vars(r) + keyID := params["keyID"] + err := logic.DeleteEnrollmentKey(keyID) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID) + w.WriteHeader(http.StatusOK) +} + +// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey +// +// Deletes a Netclient host from Netmaker server. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: hostRegisterResponse +func handleHostRegister(w http.ResponseWriter, r *http.Request) { + var params = mux.Vars(r) + keyID := params["keyID"] + err := logic.DeleteEnrollmentKey(keyID) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID) + w.WriteHeader(http.StatusOK) +} diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index e0fe3b58..778f28de 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -1,6 +1,7 @@ package logic import ( + b64 "encoding/base64" "encoding/json" "errors" "fmt" @@ -13,15 +14,19 @@ import ( // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages var EnrollmentKeyErrors = struct { - InvalidCreate error - NoKeyFound error - InvalidKey error - NoUsesRemaining error + InvalidCreate error + NoKeyFound error + InvalidKey error + NoUsesRemaining error + FailedToTokenize error + FailedToDeTokenize error }{ - InvalidCreate: fmt.Errorf("invalid enrollment key created"), - NoKeyFound: fmt.Errorf("no enrollmentkey found"), - InvalidKey: fmt.Errorf("invalid key provided"), - NoUsesRemaining: fmt.Errorf("no uses remaining"), + InvalidCreate: fmt.Errorf("invalid enrollment key created"), + NoKeyFound: fmt.Errorf("no enrollmentkey found"), + InvalidKey: fmt.Errorf("invalid key provided"), + NoUsesRemaining: fmt.Errorf("no uses remaining"), + FailedToTokenize: fmt.Errorf("failed to tokenize"), + FailedToDeTokenize: fmt.Errorf("failed to detokenize"), } // CreateEnrollmentKey - creates a new enrollment key in db @@ -109,6 +114,47 @@ func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { return false } +// Tokenize - tokenizes an enrollment key to be used via registration +// and attaches it to the Token field on the struct +func Tokenize(k *models.EnrollmentKey, serverAddr string) error { + if len(serverAddr) == 0 { + return EnrollmentKeyErrors.FailedToTokenize + } + newToken := models.EnrollmentToken{ + Server: serverAddr, + Value: k.Value, + } + data, err := json.Marshal(&newToken) + if err != nil { + return err + } + k.Token = b64.StdEncoding.EncodeToString([]byte(data)) + return nil +} + +// DeTokenize - detokenizes a base64 encoded string +// and finds the associated enrollment key +func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { + if len(b64Token) == 0 { + return nil, EnrollmentKeyErrors.FailedToDeTokenize + } + tokenData, err := b64.StdEncoding.DecodeString(b64Token) + if err != nil { + return nil, err + } + + var newToken models.EnrollmentToken + err = json.Unmarshal(tokenData, &newToken) + if err != nil { + return nil, err + } + k, err := GetEnrollmentKey(newToken.Value) + if err != nil { + return nil, err + } + return k, nil +} + // == private == // decrementEnrollmentKey - decrements the uses on a key if above 0 remaining diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 394a25c0..ae89ee48 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -4,6 +4,13 @@ import ( "time" ) +// EnrollmentToken - the tokenized version of an enrollmentkey; +// to be used for host registration +type EnrollmentToken struct { + Server string `json:"value"` + Value string `json:"value"` +} + // EnrollmentKeyLength - the length of an enrollment key const EnrollmentKeyLength = 32 @@ -15,6 +22,7 @@ type EnrollmentKey struct { Networks []string `json:"networks"` Unlimited bool `json:"unlimited"` Tags []string `json:"tags"` + Token string `json:"token,omitempty"` // B64 value of EnrollmentToken } // EnrollmentKey.IsValid - checks if the key is still valid to use From 442b32e0d93b7bd4d7981357630d66de73cc474d Mon Sep 17 00:00:00 2001 From: walkerwmanuel Date: Thu, 16 Feb 2023 10:56:13 -0500 Subject: [PATCH 05/23] Wrote test to test Enrolment Keys --- logic/enrollmentkey.go | 5 +-- logic/enrollmentkey_test.go | 71 +++++++++++++++++++++++++++++++++++++ models/enrollment_key.go | 2 +- 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 778f28de..6d865e85 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -117,7 +117,7 @@ func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { // Tokenize - tokenizes an enrollment key to be used via registration // and attaches it to the Token field on the struct func Tokenize(k *models.EnrollmentKey, serverAddr string) error { - if len(serverAddr) == 0 { + if len(serverAddr) == 0 || k == nil { return EnrollmentKeyErrors.FailedToTokenize } newToken := models.EnrollmentToken{ @@ -128,8 +128,9 @@ func Tokenize(k *models.EnrollmentKey, serverAddr string) error { if err != nil { return err } - k.Token = b64.StdEncoding.EncodeToString([]byte(data)) + k.Token = b64.StdEncoding.EncodeToString(data) return nil + } // DeTokenize - detokenizes a base64 encoded string diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index ddee0d43..f89f7f70 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -1,6 +1,7 @@ package logic import ( + "fmt" "testing" "time" @@ -127,3 +128,73 @@ func TestUsability_EnrollmentKey(t *testing.T) { func removeAllEnrollments() { database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) } + +//Test that cheks if it can tokenize +//Test that cheks if it can't tokenize + +func TestTokenize_EnrollmentKeys(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + fmt.Println(newKey.Value) + const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" + const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" + const serverAddr = "api.myserver.com" + t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) { + err := Tokenize(nil, "ServerAddress") + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentKeyErrors.FailedToTokenize) + }) + t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) { + err := Tokenize(newKey, "") + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentKeyErrors.FailedToTokenize) + }) + + t.Run("Can_Tokenize", func(t *testing.T) { + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + assert.True(t, len(newKey.Token) > 0) + }) + + t.Run("Is_Correct_B64_Token", func(t *testing.T) { + newKey.Value = defaultValue + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + assert.Equal(t, newKey.Token, b64value) + }) + removeAllEnrollments() +} + +func TestDeTokenize_EnrollmentKeys(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + fmt.Println(newKey.Value) + const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" + const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" + const serverAddr = "api.myserver.com" + + t.Run("Can_Not_DeTokenize", func(t *testing.T) { + value, err := DeTokenize("") + assert.Nil(t, value) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentKeyErrors.FailedToDeTokenize) + }) + t.Run("Can_Not_Find_Key", func(t *testing.T) { + value, err := DeTokenize(b64Value) + assert.Nil(t, value) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) + }) + t.Run("Can_DeTokenize", func(t *testing.T) { + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + output, err := DeTokenize(newKey.Token) + assert.Nil(t, err) + assert.NotNil(t, output) + assert.Equal(t, newKey.Value, output.Value) + }) + + removeAllEnrollments() +} diff --git a/models/enrollment_key.go b/models/enrollment_key.go index ae89ee48..1a58bdc1 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -7,7 +7,7 @@ import ( // EnrollmentToken - the tokenized version of an enrollmentkey; // to be used for host registration type EnrollmentToken struct { - Server string `json:"value"` + Server string `json:"server"` Value string `json:"value"` } From 8f8c4f1df0682748b89ffe4075e792bd629c4d82 Mon Sep 17 00:00:00 2001 From: walkerwmanuel Date: Thu, 16 Feb 2023 10:57:18 -0500 Subject: [PATCH 06/23] Wrote test to test Enrolment Keys --- logic/enrollmentkey_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index f89f7f70..ae223ead 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -171,7 +171,7 @@ func TestDeTokenize_EnrollmentKeys(t *testing.T) { defer database.CloseDB() newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) fmt.Println(newKey.Value) - const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" + //const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" const serverAddr = "api.myserver.com" From 193ef6b6ed32ea65cd508a0ae832b501471e074e Mon Sep 17 00:00:00 2001 From: walkerwmanuel Date: Thu, 16 Feb 2023 11:08:43 -0500 Subject: [PATCH 07/23] removed print lines --- logic/enrollmentkey_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index ae223ead..3350f5d7 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -1,7 +1,6 @@ package logic import ( - "fmt" "testing" "time" @@ -136,7 +135,6 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) - fmt.Println(newKey.Value) const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" const serverAddr = "api.myserver.com" @@ -170,7 +168,6 @@ func TestDeTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) - fmt.Println(newKey.Value) //const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" const serverAddr = "api.myserver.com" From 607198d563239c20756c519e14e4bc5b4d61a57a Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 14:27:57 -0500 Subject: [PATCH 08/23] added host registration endpoint --- controllers/enrollmentkeys.go | 141 ++++++++++++++++++++++++++++++++-- logic/enrollmentkey.go | 4 +- logic/enrollmentkey_test.go | 7 +- models/enrollment_key.go | 13 +++- 4 files changed, 151 insertions(+), 14 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 44773aeb..72bcf2e5 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -2,18 +2,23 @@ package controller import ( "encoding/json" + "fmt" "net/http" + "time" "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" ) func enrollmentKeyHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost) r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet) r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/host/register", logic.SecurityCheck(true, http.HandlerFunc(handleHostRegister))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/host/register/{token}", logic.SecurityCheck(true, http.HandlerFunc(handleHostRegister))).Methods(http.MethodPost) } // swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys @@ -71,7 +76,45 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey +// swagger:route POST /api/v1/enrollment-keys enrollmentKeys createEnrollmentKey +// +// Creates an EnrollmentKey for hosts to use on Netmaker server. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: createEnrollmentKeyResponse +func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { + + var enrollmentKeyBody models.APIEnrollmentKey + + err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var newTime time.Time + if enrollmentKeyBody.Expiration > 0 { + newTime = time.Unix(enrollmentKeyBody.Expiration, 0) + } + + newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logger.Log(2, r.Header.Get("user"), "created enrollment key") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(newEnrollmentKey) +} + +// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys deleteEnrollmentKey // // Deletes a Netclient host from Netmaker server. // @@ -84,13 +127,99 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { // 200: hostRegisterResponse func handleHostRegister(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) - keyID := params["keyID"] - err := logic.DeleteEnrollmentKey(keyID) + token := params["token"] + // check if token exists + enrollmentKey, err := logic.DeTokenize(token) if err != nil { - logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) + logger.Log(0, "invalid enrollment key used", token, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + // get the host + var newHost models.Host + if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body: ", + err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID) + // check if host already exists + if ok := logic.HostExists(&newHost); ok { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register") + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest")) + return + } + // version check + if !logic.IsVersionComptatible(newHost.Version) || newHost.TrafficKeyPublic == nil { + err := fmt.Errorf("incompatible netclient") + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + key, keyErr := logic.RetrievePublicTrafficKey() + if keyErr != nil { + logger.Log(0, "error retrieving key:", keyErr.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + // use the token + if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration") + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest")) + return + } + // register host + logic.CheckHostPorts(&newHost) + if err = logic.CreateHost(&newHost); err != nil { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + + // ready the response + server := servercfg.GetServerInfo() + server.TrafficKey = key + logger.Log(2, r.Header.Get("user"), "deleted enrollment key", token) w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&server) + // notify host of changes, peer and node updates + go checkNetRegAndHostUpdate(enrollmentKey.Networks, &newHost) +} + +// run through networks and send a host update +func checkNetRegAndHostUpdate(networks []string, h *models.Host) { + // publish host update through MQ + if servercfg.IsMessageQueueBackend() { + if err := mq.HostUpdate(&models.HostUpdate{ + Action: models.UpdateHost, + Host: *h, + }); err != nil { + logger.Log(0, "failed to send host update after registration:", h.ID.String(), err.Error()) + } + } + + for i := range networks { + if ok, _ := logic.NetworkExists(networks[i]); ok { + newNode, err := logic.UpdateHostNetwork(h, networks[i], true) + if err != nil { + logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, networks[i], err.Error()) + continue + } + logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) + if servercfg.IsMessageQueueBackend() { + if err = mq.HostUpdate(&models.HostUpdate{ + Action: models.JoinHostToNetwork, + Host: *h, + Node: *newNode, + }); err != nil { + logger.Log(0, "failed to send host update to", h.ID.String(), networks[i], err.Error()) + } + } + } + } + + if servercfg.IsMessageQueueBackend() { + if err := mq.PublishPeerUpdate(); err != nil { + logger.Log(0, "failed to publish peer update after host registration -", err.Error()) + } + } } diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 6d865e85..65f5503f 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -191,9 +191,9 @@ func getUniqueEnrollmentID() (string, error) { if err != nil { return "", err } - newID := ncutils.MakeRandomString(32) + newID := ncutils.MakeRandomString(models.EnrollmentKeyLength) for _, ok := currentKeys[newID]; ok; { - newID = ncutils.MakeRandomString(32) + newID = ncutils.MakeRandomString(models.EnrollmentKeyLength) } return newID, nil } diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 3350f5d7..7e0d5b79 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -135,8 +135,8 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) - const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" - const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" + const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" + const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) { err := Tokenize(nil, "ServerAddress") @@ -168,8 +168,7 @@ func TestDeTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) - //const defaultValue = "MwEtpqTSrGd4HTO3ahYDTExKAehh6udJ" - const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0V0cHFUU3JHZDRIVE8zYWhZRFRFeEtBZWhoNnVkSiJ9" + const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" t.Run("Can_Not_DeTokenize", func(t *testing.T) { diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 1a58bdc1..1cba2ec3 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -11,10 +11,10 @@ type EnrollmentToken struct { Value string `json:"value"` } -// EnrollmentKeyLength - the length of an enrollment key +// EnrollmentKeyLength - the length of an enrollment key - 62^16 unique possibilities const EnrollmentKeyLength = 32 -// EnrollmentKey - the +// EnrollmentKey - the key used to register hosts and join them to specific networks type EnrollmentKey struct { Expiration time.Time `json:"expiration"` UsesRemaining int `json:"uses_remaining"` @@ -25,6 +25,15 @@ type EnrollmentKey struct { Token string `json:"token,omitempty"` // B64 value of EnrollmentToken } +// APIEnrollmentKey - used to create enrollment keys via API +type APIEnrollmentKey struct { + Expiration int64 `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` +} + // EnrollmentKey.IsValid - checks if the key is still valid to use func (k *EnrollmentKey) IsValid() bool { if k == nil { From 9078608bd147e105f1490e3dbd6f4db10ff56066 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 15:13:40 -0500 Subject: [PATCH 09/23] fix initial map allocation --- controllers/enrollmentkeys.go | 8 +++++++- logic/enrollmentkey.go | 8 ++++---- logic/enrollmentkey_test.go | 10 ++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 72bcf2e5..3c5b5443 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -40,7 +40,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { return } for i := range currentKeys { - if err = logic.Tokenize(currentKeys[i], servercfg.GetServer()); err != nil { + if err = logic.Tokenize(¤tKeys[i], servercfg.GetServer()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return @@ -109,6 +109,12 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + + if err = logic.Tokenize(newEnrollmentKey, servercfg.GetServer()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } logger.Log(2, r.Header.Get("user"), "created enrollment key") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(newEnrollmentKey) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 65f5503f..bd6a7f0f 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -65,14 +65,14 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string } // GetAllEnrollmentKeys - fetches all enrollment keys from DB -func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { +func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) { currentKeys, err := getEnrollmentKeysMap() if err != nil { return nil, err } - var currentKeysList = make([]*models.EnrollmentKey, len(currentKeys)) + var currentKeysList = make([]models.EnrollmentKey, 0) for k := range currentKeys { - currentKeysList = append(currentKeysList, currentKeys[k]) + currentKeysList = append(currentKeysList, *currentKeys[k]) } return currentKeysList, nil } @@ -208,7 +208,7 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { if records == nil { records = make(map[string]string) } - currentKeys := make(map[string]*models.EnrollmentKey) + currentKeys := make(map[string]*models.EnrollmentKey, 0) if len(records) > 0 { for k := range records { var currentKey models.EnrollmentKey diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 7e0d5b79..cdfc6712 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" ) @@ -45,6 +46,15 @@ func TestCreateEnrollmentKey(t *testing.T) { assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Tags) == 2) }) + + t.Run("Can_Get_List_of_Keys", func(t *testing.T) { + keys, err := GetAllEnrollmentKeys() + assert.Nil(t, err) + assert.True(t, len(keys) > 0) + for i := range keys { + assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength) + } + }) removeAllEnrollments() } From d8c7ab980e5489d937ac28836c5745d81d50b6b5 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 15:41:23 -0500 Subject: [PATCH 10/23] fixed nil pointer from dereference in loop --- controllers/enrollmentkeys.go | 5 +++-- logic/enrollmentkey.go | 7 +++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 3c5b5443..311c94fd 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -40,13 +40,14 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { return } for i := range currentKeys { - if err = logic.Tokenize(¤tKeys[i], servercfg.GetServer()); err != nil { + currentKey := currentKeys[i] + if err = logic.Tokenize(currentKey, servercfg.GetServer()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } } - // return JSON/API formatted hosts + // return JSON/API formatted keys logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(currentKeys) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index bd6a7f0f..20cd14e8 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -65,14 +65,14 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string } // GetAllEnrollmentKeys - fetches all enrollment keys from DB -func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) { +func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { currentKeys, err := getEnrollmentKeysMap() if err != nil { return nil, err } - var currentKeysList = make([]models.EnrollmentKey, 0) + var currentKeysList = []*models.EnrollmentKey{} for k := range currentKeys { - currentKeysList = append(currentKeysList, *currentKeys[k]) + currentKeysList = append(currentKeysList, currentKeys[k]) } return currentKeysList, nil } @@ -130,7 +130,6 @@ func Tokenize(k *models.EnrollmentKey, serverAddr string) error { } k.Token = b64.StdEncoding.EncodeToString(data) return nil - } // DeTokenize - detokenizes a base64 encoded string From 6e1db0bb3ff82d28c7226944375a582bf32dea13 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 16:42:08 -0500 Subject: [PATCH 11/23] removed admin security check --- controllers/enrollmentkeys.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 311c94fd..0c346c10 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -18,7 +18,7 @@ func enrollmentKeyHandlers(r *mux.Router) { r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost) r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet) r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/host/register/{token}", logic.SecurityCheck(true, http.HandlerFunc(handleHostRegister))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost) } // swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys From 596de6b9e372eb82638f8d376977bfa27892fdbc Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 18:56:45 -0500 Subject: [PATCH 12/23] shortened name --- logic/enrollmentkey.go | 18 +++++++++--------- logic/enrollmentkey_test.go | 16 ++++++++-------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index 20cd14e8..ec1d3c8f 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -12,8 +12,8 @@ import ( "github.com/gravitl/netmaker/netclient/ncutils" ) -// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages -var EnrollmentKeyErrors = struct { +// EnrollmentErrors - struct for holding EnrollmentKey error messages +var EnrollmentErrors = struct { InvalidCreate error NoKeyFound error InvalidKey error @@ -56,7 +56,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string k.Tags = tags } if ok := k.Validate(); !ok { - return nil, EnrollmentKeyErrors.InvalidCreate + return nil, EnrollmentErrors.InvalidCreate } if err = upsertEnrollmentKey(k); err != nil { return nil, err @@ -87,7 +87,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { if key, ok := currentKeys[value]; ok { return key, nil } - return nil, EnrollmentKeyErrors.NoKeyFound + return nil, EnrollmentErrors.NoKeyFound } // DeleteEnrollmentKey - delete's a given enrollment key by value @@ -104,7 +104,7 @@ func DeleteEnrollmentKey(value string) error { func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { key, err := decrementEnrollmentKey(k.Value) if err != nil { - if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) { + if errors.Is(err, EnrollmentErrors.NoUsesRemaining) { return k.IsValid() } } else { @@ -118,7 +118,7 @@ func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { // and attaches it to the Token field on the struct func Tokenize(k *models.EnrollmentKey, serverAddr string) error { if len(serverAddr) == 0 || k == nil { - return EnrollmentKeyErrors.FailedToTokenize + return EnrollmentErrors.FailedToTokenize } newToken := models.EnrollmentToken{ Server: serverAddr, @@ -136,7 +136,7 @@ func Tokenize(k *models.EnrollmentKey, serverAddr string) error { // and finds the associated enrollment key func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { if len(b64Token) == 0 { - return nil, EnrollmentKeyErrors.FailedToDeTokenize + return nil, EnrollmentErrors.FailedToDeTokenize } tokenData, err := b64.StdEncoding.DecodeString(b64Token) if err != nil { @@ -164,7 +164,7 @@ func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { return nil, err } if k.UsesRemaining == 0 { - return nil, EnrollmentKeyErrors.NoUsesRemaining + return nil, EnrollmentErrors.NoUsesRemaining } k.UsesRemaining = k.UsesRemaining - 1 if err = upsertEnrollmentKey(k); err != nil { @@ -176,7 +176,7 @@ func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { func upsertEnrollmentKey(k *models.EnrollmentKey) error { if k == nil { - return EnrollmentKeyErrors.InvalidKey + return EnrollmentErrors.InvalidKey } data, err := json.Marshal(k) if err != nil { diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index cdfc6712..ace8ef9a 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -16,7 +16,7 @@ func TestCreateEnrollmentKey(t *testing.T) { newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) assert.Nil(t, newKey) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate) + assert.Equal(t, err, EnrollmentErrors.InvalidCreate) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) @@ -69,12 +69,12 @@ func TestDelete_EnrollmentKey(t *testing.T) { oldKey, err := GetEnrollmentKey(newKey.Value) assert.Nil(t, oldKey) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { err := DeleteEnrollmentKey("notakey") assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) removeAllEnrollments() } @@ -97,7 +97,7 @@ func TestDecrement_EnrollmentKey(t *testing.T) { assert.Equal(t, newKey.UsesRemaining, 0) _, err := decrementEnrollmentKey(newKey.Value) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining) + assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining) }) removeAllEnrollments() @@ -151,12 +151,12 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) { err := Tokenize(nil, "ServerAddress") assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.FailedToTokenize) + assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) }) t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) { err := Tokenize(newKey, "") assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.FailedToTokenize) + assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) }) t.Run("Can_Tokenize", func(t *testing.T) { @@ -185,13 +185,13 @@ func TestDeTokenize_EnrollmentKeys(t *testing.T) { value, err := DeTokenize("") assert.Nil(t, value) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.FailedToDeTokenize) + assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize) }) t.Run("Can_Not_Find_Key", func(t *testing.T) { value, err := DeTokenize(b64Value) assert.Nil(t, value) assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) }) t.Run("Can_DeTokenize", func(t *testing.T) { err := Tokenize(newKey, serverAddr) From 08248e1b35b771ad1d5cec9f51489dee010b9c05 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 16 Feb 2023 19:34:25 -0500 Subject: [PATCH 13/23] added log --- controllers/enrollmentkeys.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 0c346c10..98d1629a 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -41,7 +41,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { } for i := range currentKeys { currentKey := currentKeys[i] - if err = logic.Tokenize(currentKey, servercfg.GetServer()); err != nil { + if err = logic.Tokenize(currentKey, servercfg.GetAPIHost()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return @@ -111,7 +111,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { return } - if err = logic.Tokenize(newEnrollmentKey, servercfg.GetServer()); err != nil { + if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return @@ -135,6 +135,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { func handleHostRegister(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) token := params["token"] + logger.Log(0, "received registration attempt with token", token) // check if token exists enrollmentKey, err := logic.DeTokenize(token) if err != nil { @@ -185,7 +186,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { // ready the response server := servercfg.GetServerInfo() server.TrafficKey = key - logger.Log(2, r.Header.Get("user"), "deleted enrollment key", token) + logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&server) // notify host of changes, peer and node updates From 6b30cef9688ebec52f7cf849edcdf8185cb2a41e Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 17 Feb 2023 11:32:02 -0500 Subject: [PATCH 14/23] handled node additions in more elegant manner --- controllers/enrollmentkeys.go | 32 +++++------------------- controllers/hosts.go | 12 +++------ logic/hostactions/hostactions.go | 43 ++++++++++++++++++++++++++++++++ logic/hosts.go | 2 +- models/host.go | 2 ++ mq/handlers.go | 8 ++++++ mq/publishers.go | 3 +++ 7 files changed, 66 insertions(+), 36 deletions(-) create mode 100644 logic/hostactions/hostactions.go diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 98d1629a..7c3ad103 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -9,8 +9,8 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" ) @@ -182,7 +182,6 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - // ready the response server := servercfg.GetServerInfo() server.TrafficKey = key @@ -196,15 +195,6 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { // run through networks and send a host update func checkNetRegAndHostUpdate(networks []string, h *models.Host) { // publish host update through MQ - if servercfg.IsMessageQueueBackend() { - if err := mq.HostUpdate(&models.HostUpdate{ - Action: models.UpdateHost, - Host: *h, - }); err != nil { - logger.Log(0, "failed to send host update after registration:", h.ID.String(), err.Error()) - } - } - for i := range networks { if ok, _ := logic.NetworkExists(networks[i]); ok { newNode, err := logic.UpdateHostNetwork(h, networks[i], true) @@ -213,21 +203,11 @@ func checkNetRegAndHostUpdate(networks []string, h *models.Host) { continue } logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) - if servercfg.IsMessageQueueBackend() { - if err = mq.HostUpdate(&models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *h, - Node: *newNode, - }); err != nil { - logger.Log(0, "failed to send host update to", h.ID.String(), networks[i], err.Error()) - } - } - } - } - - if servercfg.IsMessageQueueBackend() { - if err := mq.PublishPeerUpdate(); err != nil { - logger.Log(0, "failed to publish peer update after host registration -", err.Error()) + hostactions.AddAction(models.HostUpdate{ + Action: models.JoinHostToNetwork, + Host: *h, + Node: *newNode, + }) } } } diff --git a/controllers/hosts.go b/controllers/hosts.go index b3d25ce6..29c6f216 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "golang.org/x/crypto/bcrypt" @@ -230,18 +231,11 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { return } logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name) - if err = mq.HostUpdate(&models.HostUpdate{ + hostactions.AddAction(models.HostUpdate{ Action: models.JoinHostToNetwork, Host: *currHost, Node: *newNode, - }); err != nil { - logger.Log(0, r.Header.Get("user"), "failed to update host to join network:", hostid, network, err.Error()) - } - go func() { // notify of peer change - if err := mq.PublishPeerUpdate(); err != nil { - logger.Log(1, "error publishing peer update ", err.Error()) - } - }() + }) logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network)) w.WriteHeader(http.StatusOK) diff --git a/logic/hostactions/hostactions.go b/logic/hostactions/hostactions.go new file mode 100644 index 00000000..5b81d188 --- /dev/null +++ b/logic/hostactions/hostactions.go @@ -0,0 +1,43 @@ +package hostactions + +import ( + "sync" + + "github.com/gravitl/netmaker/models" +) + +// nodeActionHandler - handles the storage of host action updates +var nodeActionHandler sync.Map + +// AddAction - adds a host action to a host's list to be retrieved from broker update +func AddAction(hu models.HostUpdate) { + currentRecords, ok := nodeActionHandler.Load(hu.Host.ID.String()) + if !ok { // no list exists yet + nodeActionHandler.Store(hu.Host.ID.String(), []models.HostUpdate{hu}) + } else { // list exists, append to it + currentList := currentRecords.([]models.HostUpdate) + currentList = append(currentList, hu) + nodeActionHandler.Store(hu.Host.ID.String(), currentList) + } +} + +// GetAction - gets an action if exists +func GetAction(id string) *models.HostUpdate { + currentRecords, ok := nodeActionHandler.Load(id) + if !ok { + return nil + } + currentList := currentRecords.([]models.HostUpdate) + if len(currentList) > 0 { + hu := currentList[0] + nodeActionHandler.Store(hu.Host.ID.String(), currentList[1:]) + return &hu + } + return nil +} + +// [hostID][NodeAction1, NodeAction2] +// host receives nodeaction1 +// host responds with ACK or something +// mq then sends next action in list, NodeAction2 +// host responds, list is empty, finished diff --git a/logic/hosts.go b/logic/hosts.go index de05caa7..716f7b72 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -90,7 +90,7 @@ func CreateHost(h *models.Host) error { if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) { return ErrHostExists } - //encrypt that password so we never see it + // encrypt that password so we never see it hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5) if err != nil { return err diff --git a/models/host.go b/models/host.go index 743ce431..ce470591 100644 --- a/models/host.go +++ b/models/host.go @@ -74,6 +74,8 @@ const ( DeleteHost = "DELETE_HOST" // JoinHostToNetwork - constant for host network join action JoinHostToNetwork = "JOIN_HOST_TO_NETWORK" + // Acknowledgement - ACK response for hosts + Acknowledgement = "ACK" ) // HostUpdate - struct for host update diff --git a/mq/handlers.go b/mq/handlers.go index 49555238..aa3395dc 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -9,6 +9,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" @@ -144,6 +145,13 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String())) var sendPeerUpdate bool switch hostUpdate.Action { + case models.Acknowledgement: + hu := hostactions.GetAction(currentHost.ID.String()) + if err = HostUpdate(hu); err != nil { + logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) + return + } + sendPeerUpdate = true case models.UpdateHost: sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost) err := logic.UpsertHost(currentHost) diff --git a/mq/publishers.go b/mq/publishers.go index ae71e79b..0e9cea26 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -40,6 +40,9 @@ func PublishSingleHostUpdate(host *models.Host) error { if err != nil { return err } + if len(peerUpdate.Peers) == 0 { // no peers to send + return nil + } if host.ProxyEnabled { proxyUpdate, err := logic.GetProxyUpdateForHost(host) if err != nil { From a1f5d73a587c5e9f23e76af93d3c3b3dd366017d Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 17 Feb 2023 11:39:30 -0500 Subject: [PATCH 15/23] added request ack --- controllers/hosts.go | 7 +++++++ models/host.go | 2 ++ mq/handlers.go | 11 +++++++---- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/controllers/hosts.go b/controllers/hosts.go index 29c6f216..bad463a5 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -13,6 +13,7 @@ import ( "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" ) @@ -236,6 +237,12 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { Host: *currHost, Node: *newNode, }) + if servercfg.IsMessageQueueBackend() { + mq.HostUpdate(&models.HostUpdate{ + Action: models.RequestAck, + Host: *currHost, + }) + } logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network)) w.WriteHeader(http.StatusOK) diff --git a/models/host.go b/models/host.go index ce470591..86991198 100644 --- a/models/host.go +++ b/models/host.go @@ -76,6 +76,8 @@ const ( JoinHostToNetwork = "JOIN_HOST_TO_NETWORK" // Acknowledgement - ACK response for hosts Acknowledgement = "ACK" + // RequestAck - request an ACK + RequestAck = "REQ_ACK" ) // HostUpdate - struct for host update diff --git a/mq/handlers.go b/mq/handlers.go index aa3395dc..b6954cbc 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -147,11 +147,13 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { switch hostUpdate.Action { case models.Acknowledgement: hu := hostactions.GetAction(currentHost.ID.String()) - if err = HostUpdate(hu); err != nil { - logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) - return + if hu != nil { + if err = HostUpdate(hu); err != nil { + logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) + return + } + sendPeerUpdate = true } - sendPeerUpdate = true case models.UpdateHost: sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost) err := logic.UpsertHost(currentHost) @@ -170,6 +172,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { } sendPeerUpdate = true } + if sendPeerUpdate { err := PublishPeerUpdate() if err != nil { From 9103efa88f1aabe145416025e3f7e0afcab94692 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 17 Feb 2023 11:54:25 -0500 Subject: [PATCH 16/23] added request ack on register --- controllers/enrollmentkeys.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 7c3ad103..9e300adc 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -11,6 +11,7 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" ) @@ -210,4 +211,10 @@ func checkNetRegAndHostUpdate(networks []string, h *models.Host) { }) } } + if servercfg.IsMessageQueueBackend() { + mq.HostUpdate(&models.HostUpdate{ + Action: models.RequestAck, + Host: *h, + }) + } } From 541e232ad79b7b1f84cb48c60cd939c899f61f2f Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 17 Feb 2023 12:09:18 -0500 Subject: [PATCH 17/23] update comments --- controllers/enrollmentkeys.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 9e300adc..fc1809d5 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -24,7 +24,7 @@ func enrollmentKeyHandlers(r *mux.Router) { // swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys // -// Lists all hosts. +// Lists all EnrollmentKeys for admins. // // Schemes: https // @@ -56,7 +56,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { // swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey // -// Deletes a Netclient host from Netmaker server. +// Deletes an EnrollmentKey from Netmaker server. // // Schemes: https // @@ -122,9 +122,9 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(newEnrollmentKey) } -// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys deleteEnrollmentKey +// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister // -// Deletes a Netclient host from Netmaker server. +// Handles a Netclient registration with server and add nodes accordingly. // // Schemes: https // @@ -132,7 +132,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { // oauth // // Responses: -// 200: hostRegisterResponse +// 200: handleHostRegisterResponse func handleHostRegister(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) token := params["token"] From 3ab4b5be331fea242c8991e820e9bb4c06800a34 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 17 Feb 2023 12:13:38 -0500 Subject: [PATCH 18/23] fixing a pointless check I didn't write --- logic/host_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/logic/host_test.go b/logic/host_test.go index 94ce4b9d..9f178c4e 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -2,6 +2,7 @@ package logic import ( "context" + "fmt" "net" "testing" @@ -17,7 +18,8 @@ func TestCheckPorts(t *testing.T) { peerUpdate := make(chan *models.Node) go ManageZombies(context.Background(), peerUpdate) go func() { - for _ = range peerUpdate { + for y := range peerUpdate { + fmt.Printf("Pointless %v\n", y) //do nothing } }() From 0335e258addcf4cfb4dd04ccee1ee93bc86bbf08 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 24 Feb 2023 12:08:32 -0500 Subject: [PATCH 19/23] added TODO comment and allowed using enrollment key more than once --- controllers/enrollmentkeys.go | 31 +++++++++++++++++++++++-------- logic/hostactions/hostactions.go | 7 +------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index fc1809d5..2ac31f18 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -152,9 +152,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + hostExists := false // check if host already exists - if ok := logic.HostExists(&newHost); ok { - logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register") + if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks") logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest")) return } @@ -176,13 +177,27 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest")) return } - // register host - logic.CheckHostPorts(&newHost) - if err = logic.CreateHost(&newHost); err != nil { - logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return + if !hostExists { + // register host + logic.CheckHostPorts(&newHost) + if err = logic.CreateHost(&newHost); err != nil { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } else { + // need to revise the list of networks from key + // based on the ones host currently has + var networksToAdd = []string{} + currentNets := logic.GetHostNetworks(newHost.ID.String()) + for _, newNet := range enrollmentKey.Networks { + if !logic.StringSliceContains(currentNets, newNet) { + networksToAdd = append(networksToAdd, newNet) + } + } + enrollmentKey.Networks = networksToAdd } + // ready the response server := servercfg.GetServerInfo() server.TrafficKey = key diff --git a/logic/hostactions/hostactions.go b/logic/hostactions/hostactions.go index 5b81d188..fa215c1c 100644 --- a/logic/hostactions/hostactions.go +++ b/logic/hostactions/hostactions.go @@ -22,6 +22,7 @@ func AddAction(hu models.HostUpdate) { } // GetAction - gets an action if exists +// TODO: may need to move to DB rather than sync map for HA func GetAction(id string) *models.HostUpdate { currentRecords, ok := nodeActionHandler.Load(id) if !ok { @@ -35,9 +36,3 @@ func GetAction(id string) *models.HostUpdate { } return nil } - -// [hostID][NodeAction1, NodeAction2] -// host receives nodeaction1 -// host responds with ACK or something -// mq then sends next action in list, NodeAction2 -// host responds, list is empty, finished From 92d0d12e8fb87084424faff10cf1d50ee9a834ec Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Fri, 24 Feb 2023 15:37:53 -0500 Subject: [PATCH 20/23] adjusted main to use one single context --- controllers/controller.go | 15 +++------- logic/nodes.go | 30 ------------------- main.go | 52 +++++++++++---------------------- mq/mq.go | 5 ++++ stun-server/stun-server.go | 59 +++++++++++++++++++------------------- 5 files changed, 56 insertions(+), 105 deletions(-) diff --git a/controllers/controller.go b/controllers/controller.go index e4d0ef50..7fa39889 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -4,11 +4,8 @@ import ( "context" "fmt" "net/http" - "os" - "os/signal" "strings" "sync" - "syscall" "time" "github.com/gorilla/handlers" @@ -33,7 +30,7 @@ var HttpHandlers = []interface{}{ } // HandleRESTRequests - handles the rest requests -func HandleRESTRequests(wg *sync.WaitGroup) { +func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() r := mux.NewRouter() @@ -59,18 +56,14 @@ func HandleRESTRequests(wg *sync.WaitGroup) { }() logger.Log(0, "REST Server successfully started on port ", port, " (REST)") - // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) - // Ignore other incoming signals - ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt) - defer stop() - // Block main routine until a signal is received // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running <-ctx.Done() - // After receiving CTRL+C Properly stop the server logger.Log(0, "Stopping the REST server...") + if err := srv.Shutdown(context.TODO()); err != nil { + logger.Log(0, "REST shutdown error occurred -", err.Error()) + } logger.Log(0, "REST Server closed.") logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) - srv.Shutdown(context.TODO()) } diff --git a/logic/nodes.go b/logic/nodes.go index 45a0ac2c..ba4f4687 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -1,7 +1,6 @@ package logic import ( - "context" "encoding/json" "errors" "fmt" @@ -421,35 +420,6 @@ func updateProNodeACLS(node *models.Node) error { return nil } -func PurgePendingNodes(ctx context.Context) { - ticker := time.NewTicker(NodePurgeCheckTime) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - nodes, err := GetAllNodes() - if err != nil { - logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error()) - continue - } - for _, node := range nodes { - if node.PendingDelete { - modified := node.LastModified - if time.Since(modified) > NodePurgeTime { - if err := DeleteNode(&node, true); err != nil { - logger.Log(0, "failed to purge node", node.ID.String(), err.Error()) - } else { - logger.Log(0, "purged node ", node.ID.String()) - } - } - } - } - } - } -} - // createNode - creates a node in database func createNode(node *models.Node) error { host, err := GetHost(node.HostID.String()) diff --git a/main.go b/main.go index 53cf4712..0cda211e 100644 --- a/main.go +++ b/main.go @@ -36,12 +36,16 @@ func main() { setupConfig(*absoluteConfigPath) servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo - // fmt.Println(models.ProLogo()) - initialize() // initial db and acls; gen cert if required + initialize() // initial db and acls setGarbageCollection() setVerbosity() defer database.CloseDB() - startControllers() // start the api endpoint and mq + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt) + defer stop() + var waitGroup sync.WaitGroup + startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun + <-ctx.Done() + waitGroup.Wait() } func setupConfig(absoluteConfigPath string) { @@ -110,8 +114,7 @@ func initialize() { // Client Mode Prereq Check } } -func startControllers() { - var waitnetwork sync.WaitGroup +func startControllers(wg *sync.WaitGroup, ctx context.Context) { if servercfg.IsDNSMode() { err := logic.SetDNS() if err != nil { @@ -127,13 +130,13 @@ func startControllers() { logger.FatalLog("Unable to Set host. Exiting...", err.Error()) } } - waitnetwork.Add(1) - go controller.HandleRESTRequests(&waitnetwork) + wg.Add(1) + go controller.HandleRESTRequests(wg, ctx) } //Run MessageQueue if servercfg.IsMessageQueueBackend() { - waitnetwork.Add(1) - go runMessageQueue(&waitnetwork) + wg.Add(1) + go runMessageQueue(wg, ctx) } if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() { @@ -141,34 +144,17 @@ func startControllers() { } // starts the stun server - waitnetwork.Add(1) - go stunserver.Start(&waitnetwork) - if servercfg.IsProxyEnabled() { - - waitnetwork.Add(1) - go func() { - defer waitnetwork.Done() - _, cancel := context.WithCancel(context.Background()) - waitnetwork.Add(1) - - //go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost()) - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() - }() - } - - waitnetwork.Wait() + wg.Add(1) + go stunserver.Start(wg, ctx) } // Should we be using a context vice a waitgroup???????????? -func runMessageQueue(wg *sync.WaitGroup) { +func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() brokerHost, secure := servercfg.GetMessageQueueEndpoint() logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure)) mq.SetupMQTT() - ctx, cancel := context.WithCancel(context.Background()) + defer mq.CloseClient() go mq.Keepalive(ctx) go func() { peerUpdate := make(chan *models.Node) @@ -179,11 +165,7 @@ func runMessageQueue(wg *sync.WaitGroup) { } } }() - go logic.PurgePendingNodes(ctx) - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() + <-ctx.Done() logger.Log(0, "Message Queue shutting down") } diff --git a/mq/mq.go b/mq/mq.go index 056dc925..9fabc90f 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -100,3 +100,8 @@ func Keepalive(ctx context.Context) { func IsConnected() bool { return mqclient != nil && mqclient.IsConnected() } + +// CloseClient - function to close the mq connection from server +func CloseClient() { + mqclient.Disconnect(250) +} diff --git a/stun-server/stun-server.go b/stun-server/stun-server.go index 7e4b768e..9bc22b14 100644 --- a/stun-server/stun-server.go +++ b/stun-server/stun-server.go @@ -4,11 +4,8 @@ import ( "context" "fmt" "net" - "os" - "os/signal" "strings" "sync" - "syscall" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/servercfg" @@ -23,7 +20,6 @@ import ( // backwards compatibility with RFC 3489. type Server struct { Addr string - Ctx context.Context } var ( @@ -60,48 +56,58 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error { ) } -func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { +func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error { if c == nil { return nil } + go func(ctx context.Context) { + <-ctx.Done() + if c != nil { + // kill connection on server shutdown + c.Close() + } + }(ctx) + buf := make([]byte, 1024) - n, addr, err := c.ReadFrom(buf) + n, addr, err := c.ReadFrom(buf) // this be blocky af if err != nil { - logger.Log(1, "ReadFrom: %v", err.Error()) + if !strings.Contains(err.Error(), "use of closed network connection") { + logger.Log(1, "STUN read error:", err.Error()) + } return nil } + if _, err = req.Write(buf[:n]); err != nil { - logger.Log(1, "Write: %v", err.Error()) + logger.Log(1, "STUN write error:", err.Error()) return err } if err = basicProcess(addr, buf[:n], req, res); err != nil { if err == errNotSTUNMessage { return nil } - logger.Log(1, "basicProcess: %v", err.Error()) + logger.Log(1, "STUN process error:", err.Error()) return nil } _, err = c.WriteTo(res.Raw, addr) if err != nil { - logger.Log(1, "WriteTo: %v", err.Error()) + logger.Log(1, "STUN response write error", err.Error()) } return err } // Serve reads packets from connections and responds to BINDING requests. -func (s *Server) serve(c net.PacketConn) error { +func (s *Server) serve(c net.PacketConn, ctx context.Context) error { var ( res = new(stun.Message) req = new(stun.Message) ) for { select { - case <-s.Ctx.Done(): - logger.Log(0, "Shutting down stun server...") - c.Close() + case <-ctx.Done(): + logger.Log(0, "shut down STUN server") return nil default: - if err := s.serveConn(c, res, req); err != nil { + if err := s.serveConn(c, res, req, ctx); err != nil { logger.Log(1, "serve: %v", err.Error()) continue } @@ -119,9 +125,8 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error { } s := &Server{ Addr: laddr, - Ctx: ctx, } - return s.serve(c) + return s.serve(c, ctx) } func normalize(address string) string { @@ -135,19 +140,15 @@ func normalize(address string) string { } // Start - starts the stun server -func Start(wg *sync.WaitGroup) { - ctx, cancel := context.WithCancel(context.Background()) - go func(wg *sync.WaitGroup) { - defer wg.Done() - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGTERM, os.Interrupt) - <-quit - cancel() - }(wg) +func Start(wg *sync.WaitGroup, ctx context.Context) { + defer wg.Done() normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort())) logger.Log(0, "netmaker-stun listening on", normalized, "via udp") - err := listenUDPAndServe(ctx, "udp", normalized) - if err != nil { - logger.Log(0, "failed to start stun server: ", err.Error()) + if err := listenUDPAndServe(ctx, "udp", normalized); err != nil { + if strings.Contains(err.Error(), "closed network connection") { + logger.Log(0, "shutdown STUN server") + } else { + logger.Log(0, "server: ", err.Error()) + } } } From 977c9c8c199058daa0d4b40f4217bdf3bf957982 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Mon, 27 Feb 2023 12:32:07 -0500 Subject: [PATCH 21/23] send peer update after request + fix pass update issue --- controllers/enrollmentkeys.go | 3 +++ logic/hosts.go | 5 +++++ mq/handlers.go | 6 +++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 2ac31f18..9294a281 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -231,5 +231,8 @@ func checkNetRegAndHostUpdate(networks []string, h *models.Host) { Action: models.RequestAck, Host: *h, }) + if err := mq.PublishPeerUpdate(); err != nil { + logger.Log(0, "failed to publish peer update during registration -", err.Error()) + } } } diff --git a/logic/hosts.go b/logic/hosts.go index 716f7b72..b504c233 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -237,6 +237,11 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error { return err } h.Nodes = append(h.Nodes, n.ID.String()) + currentHost, err := GetHost(h.ID.String()) + if err != nil { + return err + } + h.HostPass = currentHost.HostPass return UpsertHost(h) } diff --git a/mq/handlers.go b/mq/handlers.go index 3f63cbb2..14f676b3 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -151,8 +151,12 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { if err = HostUpdate(hu); err != nil { logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) return + } else { + if err = PublishSingleHostPeerUpdate(currentHost, nil); err != nil { + logger.Log(0, "failed peers publish after join acknowledged", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) + return + } } - sendPeerUpdate = true } case models.UpdateHost: sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost) From 2749e7311b2d4f488d7fbe0d1547e1619707d7f3 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Mon, 27 Feb 2023 13:36:32 -0500 Subject: [PATCH 22/23] Revert "adjusted main to use one single context" This reverts commit 92d0d12e8fb87084424faff10cf1d50ee9a834ec. --- controllers/controller.go | 15 +++++++--- logic/nodes.go | 30 +++++++++++++++++++ main.go | 52 ++++++++++++++++++++++----------- mq/mq.go | 5 ---- stun-server/stun-server.go | 59 +++++++++++++++++++------------------- 5 files changed, 105 insertions(+), 56 deletions(-) diff --git a/controllers/controller.go b/controllers/controller.go index 7fa39889..e4d0ef50 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -4,8 +4,11 @@ import ( "context" "fmt" "net/http" + "os" + "os/signal" "strings" "sync" + "syscall" "time" "github.com/gorilla/handlers" @@ -30,7 +33,7 @@ var HttpHandlers = []interface{}{ } // HandleRESTRequests - handles the rest requests -func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { +func HandleRESTRequests(wg *sync.WaitGroup) { defer wg.Done() r := mux.NewRouter() @@ -56,14 +59,18 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { }() logger.Log(0, "REST Server successfully started on port ", port, " (REST)") + // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) + // Ignore other incoming signals + ctx, stop := signal.NotifyContext(context.TODO(), syscall.SIGTERM, os.Interrupt) + defer stop() + // Block main routine until a signal is received // As long as user doesn't press CTRL+C a message is not passed and our main routine keeps running <-ctx.Done() + // After receiving CTRL+C Properly stop the server logger.Log(0, "Stopping the REST server...") - if err := srv.Shutdown(context.TODO()); err != nil { - logger.Log(0, "REST shutdown error occurred -", err.Error()) - } logger.Log(0, "REST Server closed.") logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay))) + srv.Shutdown(context.TODO()) } diff --git a/logic/nodes.go b/logic/nodes.go index ba4f4687..45a0ac2c 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -1,6 +1,7 @@ package logic import ( + "context" "encoding/json" "errors" "fmt" @@ -420,6 +421,35 @@ func updateProNodeACLS(node *models.Node) error { return nil } +func PurgePendingNodes(ctx context.Context) { + ticker := time.NewTicker(NodePurgeCheckTime) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + nodes, err := GetAllNodes() + if err != nil { + logger.Log(0, "PurgePendingNodes failed to retrieve nodes", err.Error()) + continue + } + for _, node := range nodes { + if node.PendingDelete { + modified := node.LastModified + if time.Since(modified) > NodePurgeTime { + if err := DeleteNode(&node, true); err != nil { + logger.Log(0, "failed to purge node", node.ID.String(), err.Error()) + } else { + logger.Log(0, "purged node ", node.ID.String()) + } + } + } + } + } + } +} + // createNode - creates a node in database func createNode(node *models.Node) error { host, err := GetHost(node.HostID.String()) diff --git a/main.go b/main.go index 0cda211e..53cf4712 100644 --- a/main.go +++ b/main.go @@ -36,16 +36,12 @@ func main() { setupConfig(*absoluteConfigPath) servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo - initialize() // initial db and acls + // fmt.Println(models.ProLogo()) + initialize() // initial db and acls; gen cert if required setGarbageCollection() setVerbosity() defer database.CloseDB() - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt) - defer stop() - var waitGroup sync.WaitGroup - startControllers(&waitGroup, ctx) // start the api endpoint and mq and stun - <-ctx.Done() - waitGroup.Wait() + startControllers() // start the api endpoint and mq } func setupConfig(absoluteConfigPath string) { @@ -114,7 +110,8 @@ func initialize() { // Client Mode Prereq Check } } -func startControllers(wg *sync.WaitGroup, ctx context.Context) { +func startControllers() { + var waitnetwork sync.WaitGroup if servercfg.IsDNSMode() { err := logic.SetDNS() if err != nil { @@ -130,13 +127,13 @@ func startControllers(wg *sync.WaitGroup, ctx context.Context) { logger.FatalLog("Unable to Set host. Exiting...", err.Error()) } } - wg.Add(1) - go controller.HandleRESTRequests(wg, ctx) + waitnetwork.Add(1) + go controller.HandleRESTRequests(&waitnetwork) } //Run MessageQueue if servercfg.IsMessageQueueBackend() { - wg.Add(1) - go runMessageQueue(wg, ctx) + waitnetwork.Add(1) + go runMessageQueue(&waitnetwork) } if !servercfg.IsRestBackend() && !servercfg.IsMessageQueueBackend() { @@ -144,17 +141,34 @@ func startControllers(wg *sync.WaitGroup, ctx context.Context) { } // starts the stun server - wg.Add(1) - go stunserver.Start(wg, ctx) + waitnetwork.Add(1) + go stunserver.Start(&waitnetwork) + if servercfg.IsProxyEnabled() { + + waitnetwork.Add(1) + go func() { + defer waitnetwork.Done() + _, cancel := context.WithCancel(context.Background()) + waitnetwork.Add(1) + + //go nmproxy.Start(ctx, logic.ProxyMgmChan, servercfg.GetAPIHost()) + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, os.Interrupt) + <-quit + cancel() + }() + } + + waitnetwork.Wait() } // Should we be using a context vice a waitgroup???????????? -func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { +func runMessageQueue(wg *sync.WaitGroup) { defer wg.Done() brokerHost, secure := servercfg.GetMessageQueueEndpoint() logger.Log(0, "connecting to mq broker at", brokerHost, "with TLS?", fmt.Sprintf("%v", secure)) mq.SetupMQTT() - defer mq.CloseClient() + ctx, cancel := context.WithCancel(context.Background()) go mq.Keepalive(ctx) go func() { peerUpdate := make(chan *models.Node) @@ -165,7 +179,11 @@ func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { } } }() - <-ctx.Done() + go logic.PurgePendingNodes(ctx) + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, os.Interrupt) + <-quit + cancel() logger.Log(0, "Message Queue shutting down") } diff --git a/mq/mq.go b/mq/mq.go index 9fabc90f..056dc925 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -100,8 +100,3 @@ func Keepalive(ctx context.Context) { func IsConnected() bool { return mqclient != nil && mqclient.IsConnected() } - -// CloseClient - function to close the mq connection from server -func CloseClient() { - mqclient.Disconnect(250) -} diff --git a/stun-server/stun-server.go b/stun-server/stun-server.go index 9bc22b14..7e4b768e 100644 --- a/stun-server/stun-server.go +++ b/stun-server/stun-server.go @@ -4,8 +4,11 @@ import ( "context" "fmt" "net" + "os" + "os/signal" "strings" "sync" + "syscall" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/servercfg" @@ -20,6 +23,7 @@ import ( // backwards compatibility with RFC 3489. type Server struct { Addr string + Ctx context.Context } var ( @@ -56,58 +60,48 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error { ) } -func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message, ctx context.Context) error { +func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { if c == nil { return nil } - go func(ctx context.Context) { - <-ctx.Done() - if c != nil { - // kill connection on server shutdown - c.Close() - } - }(ctx) - buf := make([]byte, 1024) - n, addr, err := c.ReadFrom(buf) // this be blocky af + n, addr, err := c.ReadFrom(buf) if err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - logger.Log(1, "STUN read error:", err.Error()) - } + logger.Log(1, "ReadFrom: %v", err.Error()) return nil } - if _, err = req.Write(buf[:n]); err != nil { - logger.Log(1, "STUN write error:", err.Error()) + logger.Log(1, "Write: %v", err.Error()) return err } if err = basicProcess(addr, buf[:n], req, res); err != nil { if err == errNotSTUNMessage { return nil } - logger.Log(1, "STUN process error:", err.Error()) + logger.Log(1, "basicProcess: %v", err.Error()) return nil } _, err = c.WriteTo(res.Raw, addr) if err != nil { - logger.Log(1, "STUN response write error", err.Error()) + logger.Log(1, "WriteTo: %v", err.Error()) } return err } // Serve reads packets from connections and responds to BINDING requests. -func (s *Server) serve(c net.PacketConn, ctx context.Context) error { +func (s *Server) serve(c net.PacketConn) error { var ( res = new(stun.Message) req = new(stun.Message) ) for { select { - case <-ctx.Done(): - logger.Log(0, "shut down STUN server") + case <-s.Ctx.Done(): + logger.Log(0, "Shutting down stun server...") + c.Close() return nil default: - if err := s.serveConn(c, res, req, ctx); err != nil { + if err := s.serveConn(c, res, req); err != nil { logger.Log(1, "serve: %v", err.Error()) continue } @@ -125,8 +119,9 @@ func listenUDPAndServe(ctx context.Context, serverNet, laddr string) error { } s := &Server{ Addr: laddr, + Ctx: ctx, } - return s.serve(c, ctx) + return s.serve(c) } func normalize(address string) string { @@ -140,15 +135,19 @@ func normalize(address string) string { } // Start - starts the stun server -func Start(wg *sync.WaitGroup, ctx context.Context) { - defer wg.Done() +func Start(wg *sync.WaitGroup) { + ctx, cancel := context.WithCancel(context.Background()) + go func(wg *sync.WaitGroup) { + defer wg.Done() + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGTERM, os.Interrupt) + <-quit + cancel() + }(wg) normalized := normalize(fmt.Sprintf("0.0.0.0:%d", servercfg.GetStunPort())) logger.Log(0, "netmaker-stun listening on", normalized, "via udp") - if err := listenUDPAndServe(ctx, "udp", normalized); err != nil { - if strings.Contains(err.Error(), "closed network connection") { - logger.Log(0, "shutdown STUN server") - } else { - logger.Log(0, "server: ", err.Error()) - } + err := listenUDPAndServe(ctx, "udp", normalized) + if err != nil { + logger.Log(0, "failed to start stun server: ", err.Error()) } } From 9a7407f635b5d399a35317aca4bfe9f965e1bbbc Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Mon, 27 Feb 2023 19:12:07 -0500 Subject: [PATCH 23/23] updated logic to add new nodes --- controllers/enrollmentkeys.go | 8 ++++---- logic/hosts.go | 2 +- logic/peers.go | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 9294a281..73966187 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -197,7 +197,6 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { } enrollmentKey.Networks = networksToAdd } - // ready the response server := servercfg.GetServerInfo() server.TrafficKey = key @@ -212,10 +211,11 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { func checkNetRegAndHostUpdate(networks []string, h *models.Host) { // publish host update through MQ for i := range networks { - if ok, _ := logic.NetworkExists(networks[i]); ok { - newNode, err := logic.UpdateHostNetwork(h, networks[i], true) + network := networks[i] + if ok, _ := logic.NetworkExists(network); ok { + newNode, err := logic.UpdateHostNetwork(h, network, true) if err != nil { - logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, networks[i], err.Error()) + logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error()) continue } logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) diff --git a/logic/hosts.go b/logic/hosts.go index b504c233..0b995b16 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -236,12 +236,12 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error { if err != nil { return err } - h.Nodes = append(h.Nodes, n.ID.String()) currentHost, err := GetHost(h.ID.String()) if err != nil { return err } h.HostPass = currentHost.HostPass + h.Nodes = append(currentHost.Nodes, n.ID.String()) return UpsertHost(h) } diff --git a/logic/peers.go b/logic/peers.go index dbc9583c..b2b45b20 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -34,7 +34,6 @@ func GetProxyUpdateForHost(host *models.Host) (models.ProxyManagerPayload, error } else { logger.Log(0, "couldn't find relay host for: ", host.ID.String()) } - } if host.IsRelay { relayedHosts := GetRelayedHosts(host) @@ -142,9 +141,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models if deletedNode != nil { deletedNodes = append(deletedNodes, *deletedNode) } - logger.Log(1, "peer update for host ", host.ID.String()) + logger.Log(1, "peer update for host", host.ID.String()) peerIndexMap := make(map[string]int) for _, nodeID := range host.Nodes { + nodeID := nodeID node, err := GetNodeByID(nodeID) if err != nil { continue @@ -163,7 +163,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models } for _, peer := range currentPeers { peer := peer - if peer.ID == node.ID { + if peer.ID.String() == node.ID.String() { logger.Log(2, "peer update, skipping self") //skip yourself continue @@ -185,7 +185,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models continue } if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) { - log.Println("peer update, skipping node for acl") + logger.Log(2, "peer update, skipping node for acl") //skip if not permitted by acl continue }