diff --git a/config/config_test.go b/config/config_test.go index 2a8205fa..04c2c144 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,7 @@ -//Environment file for getting variables -//Currently the only thing it does is set the master password -//Should probably have it take over functions from OS such as port and mongodb connection details -//Reads from the config/environments/dev.yaml file by default +// Environment file for getting variables +// Currently the only thing it does is set the master password +// Should probably have it take over functions from OS such as port and mongodb connection details +// Reads from the config/environments/dev.yaml file by default package config import ( diff --git a/controllers/controller.go b/controllers/controller.go index 38dc3ecd..7fa39889 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -26,6 +26,7 @@ var HttpHandlers = []interface{}{ ipHandlers, loggerHandlers, hostHandlers, + enrollmentKeyHandlers, } // HandleRESTRequests - handles the rest requests diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go new file mode 100644 index 00000000..73966187 --- /dev/null +++ b/controllers/enrollmentkeys.go @@ -0,0 +1,238 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/servercfg" +) + +func enrollmentKeyHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost) +} + +// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys +// +// Lists all EnrollmentKeys for admins. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: getEnrollmentKeysSlice +func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { + currentKeys, err := logic.GetAllEnrollmentKeys() + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + for i := range currentKeys { + currentKey := currentKeys[i] + if err = logic.Tokenize(currentKey, servercfg.GetAPIHost()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } + // return JSON/API formatted keys + logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(currentKeys) +} + +// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey +// +// Deletes an EnrollmentKey from Netmaker server. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: deleteEnrollmentKeyResponse +func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { + var params = mux.Vars(r) + keyID := params["keyID"] + err := logic.DeleteEnrollmentKey(keyID) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID) + w.WriteHeader(http.StatusOK) +} + +// swagger:route POST /api/v1/enrollment-keys enrollmentKeys createEnrollmentKey +// +// Creates an EnrollmentKey for hosts to use on Netmaker server. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: createEnrollmentKeyResponse +func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { + + var enrollmentKeyBody models.APIEnrollmentKey + + err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var newTime time.Time + if enrollmentKeyBody.Expiration > 0 { + newTime = time.Unix(enrollmentKeyBody.Expiration, 0) + } + + newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + + if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logger.Log(2, r.Header.Get("user"), "created enrollment key") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(newEnrollmentKey) +} + +// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister +// +// Handles a Netclient registration with server and add nodes accordingly. +// +// Schemes: https +// +// Security: +// oauth +// +// Responses: +// 200: handleHostRegisterResponse +func handleHostRegister(w http.ResponseWriter, r *http.Request) { + var params = mux.Vars(r) + token := params["token"] + logger.Log(0, "received registration attempt with token", token) + // check if token exists + enrollmentKey, err := logic.DeTokenize(token) + if err != nil { + logger.Log(0, "invalid enrollment key used", token, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + // get the host + var newHost models.Host + if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + hostExists := false + // check if host already exists + if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks") + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest")) + return + } + // version check + if !logic.IsVersionComptatible(newHost.Version) || newHost.TrafficKeyPublic == nil { + err := fmt.Errorf("incompatible netclient") + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + key, keyErr := logic.RetrievePublicTrafficKey() + if keyErr != nil { + logger.Log(0, "error retrieving key:", keyErr.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + // use the token + if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration") + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest")) + return + } + if !hostExists { + // register host + logic.CheckHostPorts(&newHost) + if err = logic.CreateHost(&newHost); err != nil { + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } else { + // need to revise the list of networks from key + // based on the ones host currently has + var networksToAdd = []string{} + currentNets := logic.GetHostNetworks(newHost.ID.String()) + for _, newNet := range enrollmentKey.Networks { + if !logic.StringSliceContains(currentNets, newNet) { + networksToAdd = append(networksToAdd, newNet) + } + } + enrollmentKey.Networks = networksToAdd + } + // ready the response + server := servercfg.GetServerInfo() + server.TrafficKey = key + logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&server) + // notify host of changes, peer and node updates + go checkNetRegAndHostUpdate(enrollmentKey.Networks, &newHost) +} + +// run through networks and send a host update +func checkNetRegAndHostUpdate(networks []string, h *models.Host) { + // publish host update through MQ + for i := range networks { + network := networks[i] + if ok, _ := logic.NetworkExists(network); ok { + newNode, err := logic.UpdateHostNetwork(h, network, true) + if err != nil { + logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error()) + continue + } + logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) + hostactions.AddAction(models.HostUpdate{ + Action: models.JoinHostToNetwork, + Host: *h, + Node: *newNode, + }) + } + } + if servercfg.IsMessageQueueBackend() { + mq.HostUpdate(&models.HostUpdate{ + Action: models.RequestAck, + Host: *h, + }) + if err := mq.PublishPeerUpdate(); err != nil { + logger.Log(0, "failed to publish peer update during registration -", err.Error()) + } + } +} diff --git a/controllers/hosts.go b/controllers/hosts.go index b3d25ce6..bad463a5 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -10,8 +10,10 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" ) @@ -230,18 +232,17 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { return } logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name) - if err = mq.HostUpdate(&models.HostUpdate{ + hostactions.AddAction(models.HostUpdate{ Action: models.JoinHostToNetwork, Host: *currHost, Node: *newNode, - }); err != nil { - logger.Log(0, r.Header.Get("user"), "failed to update host to join network:", hostid, network, err.Error()) + }) + if servercfg.IsMessageQueueBackend() { + mq.HostUpdate(&models.HostUpdate{ + Action: models.RequestAck, + Host: *currHost, + }) } - go func() { // notify of peer change - if err := mq.PublishPeerUpdate(); err != nil { - logger.Log(1, "error publishing peer update ", err.Error()) - } - }() logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network)) w.WriteHeader(http.StatusOK) diff --git a/database/database.go b/database/database.go index 283a829b..38d4fd40 100644 --- a/database/database.go +++ b/database/database.go @@ -57,6 +57,8 @@ const ( CACHE_TABLE_NAME = "cache" // HOSTS_TABLE_NAME - the table name for hosts HOSTS_TABLE_NAME = "hosts" + // ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys + ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys" // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -138,6 +140,7 @@ func createTables() { createTable(USER_GROUPS_TABLE_NAME) createTable(CACHE_TABLE_NAME) createTable(HOSTS_TABLE_NAME) + createTable(ENROLLMENT_KEYS_TABLE_NAME) } func createTable(tableName string) error { diff --git a/logic/accesskeys_test.go b/logic/accesskeys_test.go index 030508a9..fe5443fc 100644 --- a/logic/accesskeys_test.go +++ b/logic/accesskeys_test.go @@ -5,7 +5,6 @@ import "testing" func Test_genKeyName(t *testing.T) { for i := 0; i < 100; i++ { kname := genKeyName() - t.Log(kname) if len(kname) != 20 { t.Fatalf("improper length of key name, expected 20 got :%d", len(kname)) } @@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) { func Test_genKey(t *testing.T) { for i := 0; i < 100; i++ { kname := GenKey() - t.Log(kname) if len(kname) != 16 { t.Fatalf("improper length of key name, expected 16 got :%d", len(kname)) } diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go new file mode 100644 index 00000000..ec1d3c8f --- /dev/null +++ b/logic/enrollmentkey.go @@ -0,0 +1,221 @@ +package logic + +import ( + b64 "encoding/base64" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/ncutils" +) + +// EnrollmentErrors - struct for holding EnrollmentKey error messages +var EnrollmentErrors = struct { + InvalidCreate error + NoKeyFound error + InvalidKey error + NoUsesRemaining error + FailedToTokenize error + FailedToDeTokenize error +}{ + InvalidCreate: fmt.Errorf("invalid enrollment key created"), + NoKeyFound: fmt.Errorf("no enrollmentkey found"), + InvalidKey: fmt.Errorf("invalid key provided"), + NoUsesRemaining: fmt.Errorf("no uses remaining"), + FailedToTokenize: fmt.Errorf("failed to tokenize"), + FailedToDeTokenize: fmt.Errorf("failed to detokenize"), +} + +// CreateEnrollmentKey - creates a new enrollment key in db +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) { + newKeyID, err := getUniqueEnrollmentID() + if err != nil { + return nil, err + } + k = &models.EnrollmentKey{ + Value: newKeyID, + Expiration: time.Time{}, + UsesRemaining: 0, + Unlimited: unlimited, + Networks: []string{}, + Tags: []string{}, + } + if uses > 0 { + k.UsesRemaining = uses + } + if !expiration.IsZero() { + k.Expiration = expiration + } + if len(networks) > 0 { + k.Networks = networks + } + if len(tags) > 0 { + k.Tags = tags + } + if ok := k.Validate(); !ok { + return nil, EnrollmentErrors.InvalidCreate + } + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + return +} + +// GetAllEnrollmentKeys - fetches all enrollment keys from DB +func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + var currentKeysList = []*models.EnrollmentKey{} + for k := range currentKeys { + currentKeysList = append(currentKeysList, currentKeys[k]) + } + return currentKeysList, nil +} + +// GetEnrollmentKey - fetches a single enrollment key +// returns nil and error if not found +func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return nil, err + } + if key, ok := currentKeys[value]; ok { + return key, nil + } + return nil, EnrollmentErrors.NoKeyFound +} + +// DeleteEnrollmentKey - delete's a given enrollment key by value +func DeleteEnrollmentKey(value string) error { + _, err := GetEnrollmentKey(value) + if err != nil { + return err + } + return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) +} + +// TryToUseEnrollmentKey - checks first if key can be decremented +// returns true if it is decremented or isvalid +func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { + key, err := decrementEnrollmentKey(k.Value) + if err != nil { + if errors.Is(err, EnrollmentErrors.NoUsesRemaining) { + return k.IsValid() + } + } else { + k.UsesRemaining = key.UsesRemaining + return true + } + return false +} + +// Tokenize - tokenizes an enrollment key to be used via registration +// and attaches it to the Token field on the struct +func Tokenize(k *models.EnrollmentKey, serverAddr string) error { + if len(serverAddr) == 0 || k == nil { + return EnrollmentErrors.FailedToTokenize + } + newToken := models.EnrollmentToken{ + Server: serverAddr, + Value: k.Value, + } + data, err := json.Marshal(&newToken) + if err != nil { + return err + } + k.Token = b64.StdEncoding.EncodeToString(data) + return nil +} + +// DeTokenize - detokenizes a base64 encoded string +// and finds the associated enrollment key +func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { + if len(b64Token) == 0 { + return nil, EnrollmentErrors.FailedToDeTokenize + } + tokenData, err := b64.StdEncoding.DecodeString(b64Token) + if err != nil { + return nil, err + } + + var newToken models.EnrollmentToken + err = json.Unmarshal(tokenData, &newToken) + if err != nil { + return nil, err + } + k, err := GetEnrollmentKey(newToken.Value) + if err != nil { + return nil, err + } + return k, nil +} + +// == private == + +// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining +func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { + k, err := GetEnrollmentKey(value) + if err != nil { + return nil, err + } + if k.UsesRemaining == 0 { + return nil, EnrollmentErrors.NoUsesRemaining + } + k.UsesRemaining = k.UsesRemaining - 1 + if err = upsertEnrollmentKey(k); err != nil { + return nil, err + } + + return k, nil +} + +func upsertEnrollmentKey(k *models.EnrollmentKey) error { + if k == nil { + return EnrollmentErrors.InvalidKey + } + data, err := json.Marshal(k) + if err != nil { + return err + } + return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) +} + +func getUniqueEnrollmentID() (string, error) { + currentKeys, err := getEnrollmentKeysMap() + if err != nil { + return "", err + } + newID := ncutils.MakeRandomString(models.EnrollmentKeyLength) + for _, ok := currentKeys[newID]; ok; { + newID = ncutils.MakeRandomString(models.EnrollmentKeyLength) + } + return newID, nil +} + +func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) { + records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) + if err != nil { + if !database.IsEmptyRecord(err) { + return nil, err + } + } + if records == nil { + records = make(map[string]string) + } + currentKeys := make(map[string]*models.EnrollmentKey, 0) + if len(records) > 0 { + for k := range records { + var currentKey models.EnrollmentKey + if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil { + continue + } + currentKeys[k] = ¤tKey + } + } + return currentKeys, nil +} diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go new file mode 100644 index 00000000..ace8ef9a --- /dev/null +++ b/logic/enrollmentkey_test.go @@ -0,0 +1,206 @@ +package logic + +import ( + "testing" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/stretchr/testify/assert" +) + +func TestCreateEnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + t.Run("Can_Not_Create_Key", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false) + assert.Nil(t, newKey) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.InvalidCreate) + }) + t.Run("Can_Create_Key_Uses", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + assert.Nil(t, err) + assert.Equal(t, 1, newKey.UsesRemaining) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Time", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + }) + t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Networks) == 2) + }) + t.Run("Can_Create_Key_WithTags", func(t *testing.T) { + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true) + assert.Nil(t, err) + assert.True(t, newKey.IsValid()) + assert.True(t, len(newKey.Tags) == 2) + }) + + t.Run("Can_Get_List_of_Keys", func(t *testing.T) { + keys, err := GetAllEnrollmentKeys() + assert.Nil(t, err) + assert.True(t, len(keys) > 0) + for i := range keys { + assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength) + } + }) + removeAllEnrollments() +} + +func TestDelete_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + t.Run("Can_Delete_Key", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + err := DeleteEnrollmentKey(newKey.Value) + assert.Nil(t, err) + oldKey, err := GetEnrollmentKey(newKey.Value) + assert.Nil(t, oldKey) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) + }) + t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { + err := DeleteEnrollmentKey("notakey") + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) + }) + removeAllEnrollments() +} + +func TestDecrement_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + t.Run("Check_initial_uses", func(t *testing.T) { + assert.True(t, newKey.IsValid()) + assert.Equal(t, newKey.UsesRemaining, 1) + }) + t.Run("Check can decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 1) + k, err := decrementEnrollmentKey(newKey.Value) + assert.Nil(t, err) + newKey = k + }) + t.Run("Check can not decrement", func(t *testing.T) { + assert.Equal(t, newKey.UsesRemaining, 0) + _, err := decrementEnrollmentKey(newKey.Value) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining) + }) + + removeAllEnrollments() +} + +func TestUsability_EnrollmentKey(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true) + t.Run("Check if valid use key can be used", func(t *testing.T) { + assert.Equal(t, key1.UsesRemaining, 1) + ok := TryToUseEnrollmentKey(key1) + assert.True(t, ok) + assert.Equal(t, 0, key1.UsesRemaining) + }) + + t.Run("Check if valid time key can be used", func(t *testing.T) { + assert.True(t, !key2.Expiration.IsZero()) + ok := TryToUseEnrollmentKey(key2) + assert.True(t, ok) + }) + + t.Run("Check if valid unlimited key can be used", func(t *testing.T) { + assert.True(t, key3.Unlimited) + ok := TryToUseEnrollmentKey(key3) + assert.True(t, ok) + }) + + t.Run("check invalid key can not be used", func(t *testing.T) { + ok := TryToUseEnrollmentKey(key1) + assert.False(t, ok) + }) +} + +func removeAllEnrollments() { + database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) +} + +//Test that cheks if it can tokenize +//Test that cheks if it can't tokenize + +func TestTokenize_EnrollmentKeys(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" + const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" + const serverAddr = "api.myserver.com" + t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) { + err := Tokenize(nil, "ServerAddress") + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) + }) + t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) { + err := Tokenize(newKey, "") + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) + }) + + t.Run("Can_Tokenize", func(t *testing.T) { + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + assert.True(t, len(newKey.Token) > 0) + }) + + t.Run("Is_Correct_B64_Token", func(t *testing.T) { + newKey.Value = defaultValue + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + assert.Equal(t, newKey.Token, b64value) + }) + removeAllEnrollments() +} + +func TestDeTokenize_EnrollmentKeys(t *testing.T) { + database.InitializeDatabase() + defer database.CloseDB() + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true) + const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" + const serverAddr = "api.myserver.com" + + t.Run("Can_Not_DeTokenize", func(t *testing.T) { + value, err := DeTokenize("") + assert.Nil(t, value) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize) + }) + t.Run("Can_Not_Find_Key", func(t *testing.T) { + value, err := DeTokenize(b64Value) + assert.Nil(t, value) + assert.NotNil(t, err) + assert.Equal(t, err, EnrollmentErrors.NoKeyFound) + }) + t.Run("Can_DeTokenize", func(t *testing.T) { + err := Tokenize(newKey, serverAddr) + assert.Nil(t, err) + output, err := DeTokenize(newKey.Token) + assert.Nil(t, err) + assert.NotNil(t, output) + assert.Equal(t, newKey.Value, output.Value) + }) + + removeAllEnrollments() +} diff --git a/logic/host_test.go b/logic/host_test.go index fdde345e..9f178c4e 100644 --- a/logic/host_test.go +++ b/logic/host_test.go @@ -2,37 +2,28 @@ package logic import ( "context" + "fmt" "net" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/matryer/is" ) -func TestMain(m *testing.M) { +func TestCheckPorts(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - CreateAdmin(&models.User{ - UserName: "admin", - Password: "password", - IsAdmin: true, - Networks: []string{}, - Groups: []string{}, - }) peerUpdate := make(chan *models.Node) go ManageZombies(context.Background(), peerUpdate) go func() { - for update := range peerUpdate { + for y := range peerUpdate { + fmt.Printf("Pointless %v\n", y) //do nothing - logger.Log(3, "received node update", update.Action) } }() -} -func TestCheckPorts(t *testing.T) { h := models.Host{ ID: uuid.New(), EndpointIP: net.ParseIP("192.168.1.1"), diff --git a/logic/hostactions/hostactions.go b/logic/hostactions/hostactions.go new file mode 100644 index 00000000..fa215c1c --- /dev/null +++ b/logic/hostactions/hostactions.go @@ -0,0 +1,38 @@ +package hostactions + +import ( + "sync" + + "github.com/gravitl/netmaker/models" +) + +// nodeActionHandler - handles the storage of host action updates +var nodeActionHandler sync.Map + +// AddAction - adds a host action to a host's list to be retrieved from broker update +func AddAction(hu models.HostUpdate) { + currentRecords, ok := nodeActionHandler.Load(hu.Host.ID.String()) + if !ok { // no list exists yet + nodeActionHandler.Store(hu.Host.ID.String(), []models.HostUpdate{hu}) + } else { // list exists, append to it + currentList := currentRecords.([]models.HostUpdate) + currentList = append(currentList, hu) + nodeActionHandler.Store(hu.Host.ID.String(), currentList) + } +} + +// GetAction - gets an action if exists +// TODO: may need to move to DB rather than sync map for HA +func GetAction(id string) *models.HostUpdate { + currentRecords, ok := nodeActionHandler.Load(id) + if !ok { + return nil + } + currentList := currentRecords.([]models.HostUpdate) + if len(currentList) > 0 { + hu := currentList[0] + nodeActionHandler.Store(hu.Host.ID.String(), currentList[1:]) + return &hu + } + return nil +} diff --git a/logic/hosts.go b/logic/hosts.go index de05caa7..0b995b16 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -90,7 +90,7 @@ func CreateHost(h *models.Host) error { if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) { return ErrHostExists } - //encrypt that password so we never see it + // encrypt that password so we never see it hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5) if err != nil { return err @@ -236,7 +236,12 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error { if err != nil { return err } - h.Nodes = append(h.Nodes, n.ID.String()) + currentHost, err := GetHost(h.ID.String()) + if err != nil { + return err + } + h.HostPass = currentHost.HostPass + h.Nodes = append(currentHost.Nodes, n.ID.String()) return UpsertHost(h) } diff --git a/logic/peers.go b/logic/peers.go index dbc9583c..b2b45b20 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -34,7 +34,6 @@ func GetProxyUpdateForHost(host *models.Host) (models.ProxyManagerPayload, error } else { logger.Log(0, "couldn't find relay host for: ", host.ID.String()) } - } if host.IsRelay { relayedHosts := GetRelayedHosts(host) @@ -142,9 +141,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models if deletedNode != nil { deletedNodes = append(deletedNodes, *deletedNode) } - logger.Log(1, "peer update for host ", host.ID.String()) + logger.Log(1, "peer update for host", host.ID.String()) peerIndexMap := make(map[string]int) for _, nodeID := range host.Nodes { + nodeID := nodeID node, err := GetNodeByID(nodeID) if err != nil { continue @@ -163,7 +163,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models } for _, peer := range currentPeers { peer := peer - if peer.ID == node.ID { + if peer.ID.String() == node.ID.String() { logger.Log(2, "peer update, skipping self") //skip yourself continue @@ -185,7 +185,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models continue } if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) { - log.Println("peer update, skipping node for acl") + logger.Log(2, "peer update, skipping node for acl") //skip if not permitted by acl continue } diff --git a/models/enrollment_key.go b/models/enrollment_key.go new file mode 100644 index 00000000..1cba2ec3 --- /dev/null +++ b/models/enrollment_key.go @@ -0,0 +1,59 @@ +package models + +import ( + "time" +) + +// EnrollmentToken - the tokenized version of an enrollmentkey; +// to be used for host registration +type EnrollmentToken struct { + Server string `json:"server"` + Value string `json:"value"` +} + +// EnrollmentKeyLength - the length of an enrollment key - 62^16 unique possibilities +const EnrollmentKeyLength = 32 + +// EnrollmentKey - the key used to register hosts and join them to specific networks +type EnrollmentKey struct { + Expiration time.Time `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Value string `json:"value"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` + Token string `json:"token,omitempty"` // B64 value of EnrollmentToken +} + +// APIEnrollmentKey - used to create enrollment keys via API +type APIEnrollmentKey struct { + Expiration int64 `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` +} + +// EnrollmentKey.IsValid - checks if the key is still valid to use +func (k *EnrollmentKey) IsValid() bool { + if k == nil { + return false + } + if k.UsesRemaining > 0 { + return true + } + if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) { + return true + } + + return k.Unlimited +} + +// EnrollmentKey.Validate - validate's an EnrollmentKey +// should be used during creation +func (k *EnrollmentKey) Validate() bool { + return k.Networks != nil && + k.Tags != nil && + len(k.Value) == EnrollmentKeyLength && + k.IsValid() +} diff --git a/models/host.go b/models/host.go index 743ce431..86991198 100644 --- a/models/host.go +++ b/models/host.go @@ -74,6 +74,10 @@ const ( DeleteHost = "DELETE_HOST" // JoinHostToNetwork - constant for host network join action JoinHostToNetwork = "JOIN_HOST_TO_NETWORK" + // Acknowledgement - ACK response for hosts + Acknowledgement = "ACK" + // RequestAck - request an ACK + RequestAck = "REQ_ACK" ) // HostUpdate - struct for host update diff --git a/mq/handlers.go b/mq/handlers.go index 94951a77..379fbfc2 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -9,6 +9,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" @@ -144,6 +145,19 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String())) var sendPeerUpdate bool switch hostUpdate.Action { + case models.Acknowledgement: + hu := hostactions.GetAction(currentHost.ID.String()) + if hu != nil { + if err = HostUpdate(hu); err != nil { + logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) + return + } else { + if err = PublishSingleHostPeerUpdate(currentHost, nil); err != nil { + logger.Log(0, "failed peers publish after join acknowledged", hostUpdate.Host.Name, currentHost.ID.String(), err.Error()) + return + } + } + } case models.UpdateHost: sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost) err := logic.UpsertHost(currentHost) @@ -169,6 +183,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { } sendPeerUpdate = true } + if sendPeerUpdate { err := PublishPeerUpdate() if err != nil { diff --git a/mq/publishers.go b/mq/publishers.go index c306df79..aa5391c3 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -61,6 +61,9 @@ func PublishSingleHostPeerUpdate(host *models.Host, deletedNode *models.Node) er if err != nil { return err } + if len(peerUpdate.Peers) == 0 { // no peers to send + return nil + } if host.ProxyEnabled { proxyUpdate, err := logic.GetProxyUpdateForHost(host) if err != nil {