Merge pull request #2050 from gravitl/GRA-1198-enrollment_keys

Gra 1198 enrollment keys
This commit is contained in:
dcarns 2023-02-28 09:26:23 -05:00 committed by GitHub
commit ad4bab064b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 816 additions and 33 deletions

View file

@ -1,7 +1,7 @@
//Environment file for getting variables
//Currently the only thing it does is set the master password
//Should probably have it take over functions from OS such as port and mongodb connection details
//Reads from the config/environments/dev.yaml file by default
// Environment file for getting variables
// Currently the only thing it does is set the master password
// Should probably have it take over functions from OS such as port and mongodb connection details
// Reads from the config/environments/dev.yaml file by default
package config
import (

View file

@ -26,6 +26,7 @@ var HttpHandlers = []interface{}{
ipHandlers,
loggerHandlers,
hostHandlers,
enrollmentKeyHandlers,
}
// HandleRESTRequests - handles the rest requests

View file

@ -0,0 +1,238 @@
package controller
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/servercfg"
)
func enrollmentKeyHandlers(r *mux.Router) {
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost)
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete)
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost)
}
// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
//
// Lists all EnrollmentKeys for admins.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: getEnrollmentKeysSlice
func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
currentKeys, err := logic.GetAllEnrollmentKeys()
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
for i := range currentKeys {
currentKey := currentKeys[i]
if err = logic.Tokenize(currentKey, servercfg.GetAPIHost()); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
}
// return JSON/API formatted keys
logger.Log(2, r.Header.Get("user"), "fetched enrollment keys")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(currentKeys)
}
// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey
//
// Deletes an EnrollmentKey from Netmaker server.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: deleteEnrollmentKeyResponse
func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
keyID := params["keyID"]
err := logic.DeleteEnrollmentKey(keyID)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID)
w.WriteHeader(http.StatusOK)
}
// swagger:route POST /api/v1/enrollment-keys enrollmentKeys createEnrollmentKey
//
// Creates an EnrollmentKey for hosts to use on Netmaker server.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: createEnrollmentKeyResponse
func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
var enrollmentKeyBody models.APIEnrollmentKey
err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
if err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
var newTime time.Time
if enrollmentKeyBody.Expiration > 0 {
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
}
newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
logger.Log(2, r.Header.Get("user"), "created enrollment key")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(newEnrollmentKey)
}
// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
//
// Handles a Netclient registration with server and add nodes accordingly.
//
// Schemes: https
//
// Security:
// oauth
//
// Responses:
// 200: handleHostRegisterResponse
func handleHostRegister(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
token := params["token"]
logger.Log(0, "received registration attempt with token", token)
// check if token exists
enrollmentKey, err := logic.DeTokenize(token)
if err != nil {
logger.Log(0, "invalid enrollment key used", token, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
// get the host
var newHost models.Host
if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
hostExists := false
// check if host already exists
if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks")
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest"))
return
}
// version check
if !logic.IsVersionComptatible(newHost.Version) || newHost.TrafficKeyPublic == nil {
err := fmt.Errorf("incompatible netclient")
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
key, keyErr := logic.RetrievePublicTrafficKey()
if keyErr != nil {
logger.Log(0, "error retrieving key:", keyErr.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
// use the token
if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration")
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"))
return
}
if !hostExists {
// register host
logic.CheckHostPorts(&newHost)
if err = logic.CreateHost(&newHost); err != nil {
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
} else {
// need to revise the list of networks from key
// based on the ones host currently has
var networksToAdd = []string{}
currentNets := logic.GetHostNetworks(newHost.ID.String())
for _, newNet := range enrollmentKey.Networks {
if !logic.StringSliceContains(currentNets, newNet) {
networksToAdd = append(networksToAdd, newNet)
}
}
enrollmentKey.Networks = networksToAdd
}
// ready the response
server := servercfg.GetServerInfo()
server.TrafficKey = key
logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&server)
// notify host of changes, peer and node updates
go checkNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
}
// run through networks and send a host update
func checkNetRegAndHostUpdate(networks []string, h *models.Host) {
// publish host update through MQ
for i := range networks {
network := networks[i]
if ok, _ := logic.NetworkExists(network); ok {
newNode, err := logic.UpdateHostNetwork(h, network, true)
if err != nil {
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
continue
}
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
hostactions.AddAction(models.HostUpdate{
Action: models.JoinHostToNetwork,
Host: *h,
Node: *newNode,
})
}
}
if servercfg.IsMessageQueueBackend() {
mq.HostUpdate(&models.HostUpdate{
Action: models.RequestAck,
Host: *h,
})
if err := mq.PublishPeerUpdate(); err != nil {
logger.Log(0, "failed to publish peer update during registration -", err.Error())
}
}
}

View file

@ -10,8 +10,10 @@ import (
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
)
@ -230,18 +232,17 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
return
}
logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name)
if err = mq.HostUpdate(&models.HostUpdate{
hostactions.AddAction(models.HostUpdate{
Action: models.JoinHostToNetwork,
Host: *currHost,
Node: *newNode,
}); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update host to join network:", hostid, network, err.Error())
})
if servercfg.IsMessageQueueBackend() {
mq.HostUpdate(&models.HostUpdate{
Action: models.RequestAck,
Host: *currHost,
})
}
go func() { // notify of peer change
if err := mq.PublishPeerUpdate(); err != nil {
logger.Log(1, "error publishing peer update ", err.Error())
}
}()
logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network))
w.WriteHeader(http.StatusOK)

View file

@ -57,6 +57,8 @@ const (
CACHE_TABLE_NAME = "cache"
// HOSTS_TABLE_NAME - the table name for hosts
HOSTS_TABLE_NAME = "hosts"
// ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys
ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys"
// == ERROR CONSTS ==
// NO_RECORD - no singular result found
@ -138,6 +140,7 @@ func createTables() {
createTable(USER_GROUPS_TABLE_NAME)
createTable(CACHE_TABLE_NAME)
createTable(HOSTS_TABLE_NAME)
createTable(ENROLLMENT_KEYS_TABLE_NAME)
}
func createTable(tableName string) error {

View file

@ -5,7 +5,6 @@ import "testing"
func Test_genKeyName(t *testing.T) {
for i := 0; i < 100; i++ {
kname := genKeyName()
t.Log(kname)
if len(kname) != 20 {
t.Fatalf("improper length of key name, expected 20 got :%d", len(kname))
}
@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) {
func Test_genKey(t *testing.T) {
for i := 0; i < 100; i++ {
kname := GenKey()
t.Log(kname)
if len(kname) != 16 {
t.Fatalf("improper length of key name, expected 16 got :%d", len(kname))
}

221
logic/enrollmentkey.go Normal file
View file

@ -0,0 +1,221 @@
package logic
import (
b64 "encoding/base64"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
)
// EnrollmentErrors - struct for holding EnrollmentKey error messages
var EnrollmentErrors = struct {
InvalidCreate error
NoKeyFound error
InvalidKey error
NoUsesRemaining error
FailedToTokenize error
FailedToDeTokenize error
}{
InvalidCreate: fmt.Errorf("invalid enrollment key created"),
NoKeyFound: fmt.Errorf("no enrollmentkey found"),
InvalidKey: fmt.Errorf("invalid key provided"),
NoUsesRemaining: fmt.Errorf("no uses remaining"),
FailedToTokenize: fmt.Errorf("failed to tokenize"),
FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
}
// CreateEnrollmentKey - creates a new enrollment key in db
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
newKeyID, err := getUniqueEnrollmentID()
if err != nil {
return nil, err
}
k = &models.EnrollmentKey{
Value: newKeyID,
Expiration: time.Time{},
UsesRemaining: 0,
Unlimited: unlimited,
Networks: []string{},
Tags: []string{},
}
if uses > 0 {
k.UsesRemaining = uses
}
if !expiration.IsZero() {
k.Expiration = expiration
}
if len(networks) > 0 {
k.Networks = networks
}
if len(tags) > 0 {
k.Tags = tags
}
if ok := k.Validate(); !ok {
return nil, EnrollmentErrors.InvalidCreate
}
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
}
return
}
// GetAllEnrollmentKeys - fetches all enrollment keys from DB
func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return nil, err
}
var currentKeysList = []*models.EnrollmentKey{}
for k := range currentKeys {
currentKeysList = append(currentKeysList, currentKeys[k])
}
return currentKeysList, nil
}
// GetEnrollmentKey - fetches a single enrollment key
// returns nil and error if not found
func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return nil, err
}
if key, ok := currentKeys[value]; ok {
return key, nil
}
return nil, EnrollmentErrors.NoKeyFound
}
// DeleteEnrollmentKey - delete's a given enrollment key by value
func DeleteEnrollmentKey(value string) error {
_, err := GetEnrollmentKey(value)
if err != nil {
return err
}
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
}
// TryToUseEnrollmentKey - checks first if key can be decremented
// returns true if it is decremented or isvalid
func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
key, err := decrementEnrollmentKey(k.Value)
if err != nil {
if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
return k.IsValid()
}
} else {
k.UsesRemaining = key.UsesRemaining
return true
}
return false
}
// Tokenize - tokenizes an enrollment key to be used via registration
// and attaches it to the Token field on the struct
func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
if len(serverAddr) == 0 || k == nil {
return EnrollmentErrors.FailedToTokenize
}
newToken := models.EnrollmentToken{
Server: serverAddr,
Value: k.Value,
}
data, err := json.Marshal(&newToken)
if err != nil {
return err
}
k.Token = b64.StdEncoding.EncodeToString(data)
return nil
}
// DeTokenize - detokenizes a base64 encoded string
// and finds the associated enrollment key
func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
if len(b64Token) == 0 {
return nil, EnrollmentErrors.FailedToDeTokenize
}
tokenData, err := b64.StdEncoding.DecodeString(b64Token)
if err != nil {
return nil, err
}
var newToken models.EnrollmentToken
err = json.Unmarshal(tokenData, &newToken)
if err != nil {
return nil, err
}
k, err := GetEnrollmentKey(newToken.Value)
if err != nil {
return nil, err
}
return k, nil
}
// == private ==
// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
k, err := GetEnrollmentKey(value)
if err != nil {
return nil, err
}
if k.UsesRemaining == 0 {
return nil, EnrollmentErrors.NoUsesRemaining
}
k.UsesRemaining = k.UsesRemaining - 1
if err = upsertEnrollmentKey(k); err != nil {
return nil, err
}
return k, nil
}
func upsertEnrollmentKey(k *models.EnrollmentKey) error {
if k == nil {
return EnrollmentErrors.InvalidKey
}
data, err := json.Marshal(k)
if err != nil {
return err
}
return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
}
func getUniqueEnrollmentID() (string, error) {
currentKeys, err := getEnrollmentKeysMap()
if err != nil {
return "", err
}
newID := ncutils.MakeRandomString(models.EnrollmentKeyLength)
for _, ok := currentKeys[newID]; ok; {
newID = ncutils.MakeRandomString(models.EnrollmentKeyLength)
}
return newID, nil
}
func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
if err != nil {
if !database.IsEmptyRecord(err) {
return nil, err
}
}
if records == nil {
records = make(map[string]string)
}
currentKeys := make(map[string]*models.EnrollmentKey, 0)
if len(records) > 0 {
for k := range records {
var currentKey models.EnrollmentKey
if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
continue
}
currentKeys[k] = &currentKey
}
}
return currentKeys, nil
}

206
logic/enrollmentkey_test.go Normal file
View file

@ -0,0 +1,206 @@
package logic
import (
"testing"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/stretchr/testify/assert"
)
func TestCreateEnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
t.Run("Can_Not_Create_Key", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
assert.Nil(t, newKey)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
})
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
assert.Nil(t, err)
assert.Equal(t, 1, newKey.UsesRemaining)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Time", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
})
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Networks) == 2)
})
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
assert.Nil(t, err)
assert.True(t, newKey.IsValid())
assert.True(t, len(newKey.Tags) == 2)
})
t.Run("Can_Get_List_of_Keys", func(t *testing.T) {
keys, err := GetAllEnrollmentKeys()
assert.Nil(t, err)
assert.True(t, len(keys) > 0)
for i := range keys {
assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength)
}
})
removeAllEnrollments()
}
func TestDelete_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
t.Run("Can_Delete_Key", func(t *testing.T) {
assert.True(t, newKey.IsValid())
err := DeleteEnrollmentKey(newKey.Value)
assert.Nil(t, err)
oldKey, err := GetEnrollmentKey(newKey.Value)
assert.Nil(t, oldKey)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
})
t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
err := DeleteEnrollmentKey("notakey")
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
})
removeAllEnrollments()
}
func TestDecrement_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
t.Run("Check_initial_uses", func(t *testing.T) {
assert.True(t, newKey.IsValid())
assert.Equal(t, newKey.UsesRemaining, 1)
})
t.Run("Check can decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 1)
k, err := decrementEnrollmentKey(newKey.Value)
assert.Nil(t, err)
newKey = k
})
t.Run("Check can not decrement", func(t *testing.T) {
assert.Equal(t, newKey.UsesRemaining, 0)
_, err := decrementEnrollmentKey(newKey.Value)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining)
})
removeAllEnrollments()
}
func TestUsability_EnrollmentKey(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
t.Run("Check if valid use key can be used", func(t *testing.T) {
assert.Equal(t, key1.UsesRemaining, 1)
ok := TryToUseEnrollmentKey(key1)
assert.True(t, ok)
assert.Equal(t, 0, key1.UsesRemaining)
})
t.Run("Check if valid time key can be used", func(t *testing.T) {
assert.True(t, !key2.Expiration.IsZero())
ok := TryToUseEnrollmentKey(key2)
assert.True(t, ok)
})
t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
assert.True(t, key3.Unlimited)
ok := TryToUseEnrollmentKey(key3)
assert.True(t, ok)
})
t.Run("check invalid key can not be used", func(t *testing.T) {
ok := TryToUseEnrollmentKey(key1)
assert.False(t, ok)
})
}
func removeAllEnrollments() {
database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
}
//Test that cheks if it can tokenize
//Test that cheks if it can't tokenize
func TestTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"
t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) {
err := Tokenize(nil, "ServerAddress")
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
})
t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) {
err := Tokenize(newKey, "")
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
})
t.Run("Can_Tokenize", func(t *testing.T) {
err := Tokenize(newKey, serverAddr)
assert.Nil(t, err)
assert.True(t, len(newKey.Token) > 0)
})
t.Run("Is_Correct_B64_Token", func(t *testing.T) {
newKey.Value = defaultValue
err := Tokenize(newKey, serverAddr)
assert.Nil(t, err)
assert.Equal(t, newKey.Token, b64value)
})
removeAllEnrollments()
}
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
const serverAddr = "api.myserver.com"
t.Run("Can_Not_DeTokenize", func(t *testing.T) {
value, err := DeTokenize("")
assert.Nil(t, value)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize)
})
t.Run("Can_Not_Find_Key", func(t *testing.T) {
value, err := DeTokenize(b64Value)
assert.Nil(t, value)
assert.NotNil(t, err)
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
})
t.Run("Can_DeTokenize", func(t *testing.T) {
err := Tokenize(newKey, serverAddr)
assert.Nil(t, err)
output, err := DeTokenize(newKey.Token)
assert.Nil(t, err)
assert.NotNil(t, output)
assert.Equal(t, newKey.Value, output.Value)
})
removeAllEnrollments()
}

View file

@ -2,37 +2,28 @@ package logic
import (
"context"
"fmt"
"net"
"testing"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/matryer/is"
)
func TestMain(m *testing.M) {
func TestCheckPorts(t *testing.T) {
database.InitializeDatabase()
defer database.CloseDB()
CreateAdmin(&models.User{
UserName: "admin",
Password: "password",
IsAdmin: true,
Networks: []string{},
Groups: []string{},
})
peerUpdate := make(chan *models.Node)
go ManageZombies(context.Background(), peerUpdate)
go func() {
for update := range peerUpdate {
for y := range peerUpdate {
fmt.Printf("Pointless %v\n", y)
//do nothing
logger.Log(3, "received node update", update.Action)
}
}()
}
func TestCheckPorts(t *testing.T) {
h := models.Host{
ID: uuid.New(),
EndpointIP: net.ParseIP("192.168.1.1"),

View file

@ -0,0 +1,38 @@
package hostactions
import (
"sync"
"github.com/gravitl/netmaker/models"
)
// nodeActionHandler - handles the storage of host action updates
var nodeActionHandler sync.Map
// AddAction - adds a host action to a host's list to be retrieved from broker update
func AddAction(hu models.HostUpdate) {
currentRecords, ok := nodeActionHandler.Load(hu.Host.ID.String())
if !ok { // no list exists yet
nodeActionHandler.Store(hu.Host.ID.String(), []models.HostUpdate{hu})
} else { // list exists, append to it
currentList := currentRecords.([]models.HostUpdate)
currentList = append(currentList, hu)
nodeActionHandler.Store(hu.Host.ID.String(), currentList)
}
}
// GetAction - gets an action if exists
// TODO: may need to move to DB rather than sync map for HA
func GetAction(id string) *models.HostUpdate {
currentRecords, ok := nodeActionHandler.Load(id)
if !ok {
return nil
}
currentList := currentRecords.([]models.HostUpdate)
if len(currentList) > 0 {
hu := currentList[0]
nodeActionHandler.Store(hu.Host.ID.String(), currentList[1:])
return &hu
}
return nil
}

View file

@ -90,7 +90,7 @@ func CreateHost(h *models.Host) error {
if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
return ErrHostExists
}
//encrypt that password so we never see it
// encrypt that password so we never see it
hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5)
if err != nil {
return err
@ -236,7 +236,12 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
if err != nil {
return err
}
h.Nodes = append(h.Nodes, n.ID.String())
currentHost, err := GetHost(h.ID.String())
if err != nil {
return err
}
h.HostPass = currentHost.HostPass
h.Nodes = append(currentHost.Nodes, n.ID.String())
return UpsertHost(h)
}

View file

@ -34,7 +34,6 @@ func GetProxyUpdateForHost(host *models.Host) (models.ProxyManagerPayload, error
} else {
logger.Log(0, "couldn't find relay host for: ", host.ID.String())
}
}
if host.IsRelay {
relayedHosts := GetRelayedHosts(host)
@ -142,9 +141,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
if deletedNode != nil {
deletedNodes = append(deletedNodes, *deletedNode)
}
logger.Log(1, "peer update for host ", host.ID.String())
logger.Log(1, "peer update for host", host.ID.String())
peerIndexMap := make(map[string]int)
for _, nodeID := range host.Nodes {
nodeID := nodeID
node, err := GetNodeByID(nodeID)
if err != nil {
continue
@ -163,7 +163,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
}
for _, peer := range currentPeers {
peer := peer
if peer.ID == node.ID {
if peer.ID.String() == node.ID.String() {
logger.Log(2, "peer update, skipping self")
//skip yourself
continue
@ -185,7 +185,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
continue
}
if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) {
log.Println("peer update, skipping node for acl")
logger.Log(2, "peer update, skipping node for acl")
//skip if not permitted by acl
continue
}

59
models/enrollment_key.go Normal file
View file

@ -0,0 +1,59 @@
package models
import (
"time"
)
// EnrollmentToken - the tokenized version of an enrollmentkey;
// to be used for host registration
type EnrollmentToken struct {
Server string `json:"server"`
Value string `json:"value"`
}
// EnrollmentKeyLength - the length of an enrollment key - 62^16 unique possibilities
const EnrollmentKeyLength = 32
// EnrollmentKey - the key used to register hosts and join them to specific networks
type EnrollmentKey struct {
Expiration time.Time `json:"expiration"`
UsesRemaining int `json:"uses_remaining"`
Value string `json:"value"`
Networks []string `json:"networks"`
Unlimited bool `json:"unlimited"`
Tags []string `json:"tags"`
Token string `json:"token,omitempty"` // B64 value of EnrollmentToken
}
// APIEnrollmentKey - used to create enrollment keys via API
type APIEnrollmentKey struct {
Expiration int64 `json:"expiration"`
UsesRemaining int `json:"uses_remaining"`
Networks []string `json:"networks"`
Unlimited bool `json:"unlimited"`
Tags []string `json:"tags"`
}
// EnrollmentKey.IsValid - checks if the key is still valid to use
func (k *EnrollmentKey) IsValid() bool {
if k == nil {
return false
}
if k.UsesRemaining > 0 {
return true
}
if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) {
return true
}
return k.Unlimited
}
// EnrollmentKey.Validate - validate's an EnrollmentKey
// should be used during creation
func (k *EnrollmentKey) Validate() bool {
return k.Networks != nil &&
k.Tags != nil &&
len(k.Value) == EnrollmentKeyLength &&
k.IsValid()
}

View file

@ -74,6 +74,10 @@ const (
DeleteHost = "DELETE_HOST"
// JoinHostToNetwork - constant for host network join action
JoinHostToNetwork = "JOIN_HOST_TO_NETWORK"
// Acknowledgement - ACK response for hosts
Acknowledgement = "ACK"
// RequestAck - request an ACK
RequestAck = "REQ_ACK"
)
// HostUpdate - struct for host update

View file

@ -9,6 +9,7 @@ import (
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg"
@ -144,6 +145,19 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String()))
var sendPeerUpdate bool
switch hostUpdate.Action {
case models.Acknowledgement:
hu := hostactions.GetAction(currentHost.ID.String())
if hu != nil {
if err = HostUpdate(hu); err != nil {
logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
return
} else {
if err = PublishSingleHostPeerUpdate(currentHost, nil); err != nil {
logger.Log(0, "failed peers publish after join acknowledged", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
return
}
}
}
case models.UpdateHost:
sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost)
err := logic.UpsertHost(currentHost)
@ -169,6 +183,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
}
sendPeerUpdate = true
}
if sendPeerUpdate {
err := PublishPeerUpdate()
if err != nil {

View file

@ -61,6 +61,9 @@ func PublishSingleHostPeerUpdate(host *models.Host, deletedNode *models.Node) er
if err != nil {
return err
}
if len(peerUpdate.Peers) == 0 { // no peers to send
return nil
}
if host.ProxyEnabled {
proxyUpdate, err := logic.GetProxyUpdateForHost(host)
if err != nil {