mirror of
https://github.com/gravitl/netmaker.git
synced 2025-10-03 10:24:24 +08:00
Merge pull request #2050 from gravitl/GRA-1198-enrollment_keys
Gra 1198 enrollment keys
This commit is contained in:
commit
ad4bab064b
16 changed files with 816 additions and 33 deletions
|
@ -1,7 +1,7 @@
|
|||
//Environment file for getting variables
|
||||
//Currently the only thing it does is set the master password
|
||||
//Should probably have it take over functions from OS such as port and mongodb connection details
|
||||
//Reads from the config/environments/dev.yaml file by default
|
||||
// Environment file for getting variables
|
||||
// Currently the only thing it does is set the master password
|
||||
// Should probably have it take over functions from OS such as port and mongodb connection details
|
||||
// Reads from the config/environments/dev.yaml file by default
|
||||
package config
|
||||
|
||||
import (
|
||||
|
|
|
@ -26,6 +26,7 @@ var HttpHandlers = []interface{}{
|
|||
ipHandlers,
|
||||
loggerHandlers,
|
||||
hostHandlers,
|
||||
enrollmentKeyHandlers,
|
||||
}
|
||||
|
||||
// HandleRESTRequests - handles the rest requests
|
||||
|
|
238
controllers/enrollmentkeys.go
Normal file
238
controllers/enrollmentkeys.go
Normal file
|
@ -0,0 +1,238 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/hostactions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
func enrollmentKeyHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost)
|
||||
}
|
||||
|
||||
// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
|
||||
//
|
||||
// Lists all EnrollmentKeys for admins.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
//
|
||||
// Responses:
|
||||
// 200: getEnrollmentKeysSlice
|
||||
func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
|
||||
currentKeys, err := logic.GetAllEnrollmentKeys()
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
for i := range currentKeys {
|
||||
currentKey := currentKeys[i]
|
||||
if err = logic.Tokenize(currentKey, servercfg.GetAPIHost()); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
}
|
||||
// return JSON/API formatted keys
|
||||
logger.Log(2, r.Header.Get("user"), "fetched enrollment keys")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(currentKeys)
|
||||
}
|
||||
|
||||
// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey
|
||||
//
|
||||
// Deletes an EnrollmentKey from Netmaker server.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
//
|
||||
// Responses:
|
||||
// 200: deleteEnrollmentKeyResponse
|
||||
func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
keyID := params["keyID"]
|
||||
err := logic.DeleteEnrollmentKey(keyID)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// swagger:route POST /api/v1/enrollment-keys enrollmentKeys createEnrollmentKey
|
||||
//
|
||||
// Creates an EnrollmentKey for hosts to use on Netmaker server.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
//
|
||||
// Responses:
|
||||
// 200: createEnrollmentKeyResponse
|
||||
func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var enrollmentKeyBody models.APIEnrollmentKey
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
var newTime time.Time
|
||||
if enrollmentKeyBody.Expiration > 0 {
|
||||
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
|
||||
}
|
||||
|
||||
newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logger.Log(2, r.Header.Get("user"), "created enrollment key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(newEnrollmentKey)
|
||||
}
|
||||
|
||||
// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
|
||||
//
|
||||
// Handles a Netclient registration with server and add nodes accordingly.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
//
|
||||
// Responses:
|
||||
// 200: handleHostRegisterResponse
|
||||
func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
token := params["token"]
|
||||
logger.Log(0, "received registration attempt with token", token)
|
||||
// check if token exists
|
||||
enrollmentKey, err := logic.DeTokenize(token)
|
||||
if err != nil {
|
||||
logger.Log(0, "invalid enrollment key used", token, err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
// get the host
|
||||
var newHost models.Host
|
||||
if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
hostExists := false
|
||||
// check if host already exists
|
||||
if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 {
|
||||
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks")
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest"))
|
||||
return
|
||||
}
|
||||
// version check
|
||||
if !logic.IsVersionComptatible(newHost.Version) || newHost.TrafficKeyPublic == nil {
|
||||
err := fmt.Errorf("incompatible netclient")
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
key, keyErr := logic.RetrievePublicTrafficKey()
|
||||
if keyErr != nil {
|
||||
logger.Log(0, "error retrieving key:", keyErr.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
// use the token
|
||||
if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok {
|
||||
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration")
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"))
|
||||
return
|
||||
}
|
||||
if !hostExists {
|
||||
// register host
|
||||
logic.CheckHostPorts(&newHost)
|
||||
if err = logic.CreateHost(&newHost); err != nil {
|
||||
logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// need to revise the list of networks from key
|
||||
// based on the ones host currently has
|
||||
var networksToAdd = []string{}
|
||||
currentNets := logic.GetHostNetworks(newHost.ID.String())
|
||||
for _, newNet := range enrollmentKey.Networks {
|
||||
if !logic.StringSliceContains(currentNets, newNet) {
|
||||
networksToAdd = append(networksToAdd, newNet)
|
||||
}
|
||||
}
|
||||
enrollmentKey.Networks = networksToAdd
|
||||
}
|
||||
// ready the response
|
||||
server := servercfg.GetServerInfo()
|
||||
server.TrafficKey = key
|
||||
logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(&server)
|
||||
// notify host of changes, peer and node updates
|
||||
go checkNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
|
||||
}
|
||||
|
||||
// run through networks and send a host update
|
||||
func checkNetRegAndHostUpdate(networks []string, h *models.Host) {
|
||||
// publish host update through MQ
|
||||
for i := range networks {
|
||||
network := networks[i]
|
||||
if ok, _ := logic.NetworkExists(network); ok {
|
||||
newNode, err := logic.UpdateHostNetwork(h, network, true)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
|
||||
continue
|
||||
}
|
||||
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
|
||||
hostactions.AddAction(models.HostUpdate{
|
||||
Action: models.JoinHostToNetwork,
|
||||
Host: *h,
|
||||
Node: *newNode,
|
||||
})
|
||||
}
|
||||
}
|
||||
if servercfg.IsMessageQueueBackend() {
|
||||
mq.HostUpdate(&models.HostUpdate{
|
||||
Action: models.RequestAck,
|
||||
Host: *h,
|
||||
})
|
||||
if err := mq.PublishPeerUpdate(); err != nil {
|
||||
logger.Log(0, "failed to publish peer update during registration -", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,8 +10,10 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/hostactions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
|
@ -230,18 +232,17 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name)
|
||||
if err = mq.HostUpdate(&models.HostUpdate{
|
||||
hostactions.AddAction(models.HostUpdate{
|
||||
Action: models.JoinHostToNetwork,
|
||||
Host: *currHost,
|
||||
Node: *newNode,
|
||||
}); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to update host to join network:", hostid, network, err.Error())
|
||||
})
|
||||
if servercfg.IsMessageQueueBackend() {
|
||||
mq.HostUpdate(&models.HostUpdate{
|
||||
Action: models.RequestAck,
|
||||
Host: *currHost,
|
||||
})
|
||||
}
|
||||
go func() { // notify of peer change
|
||||
if err := mq.PublishPeerUpdate(); err != nil {
|
||||
logger.Log(1, "error publishing peer update ", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
|
@ -57,6 +57,8 @@ const (
|
|||
CACHE_TABLE_NAME = "cache"
|
||||
// HOSTS_TABLE_NAME - the table name for hosts
|
||||
HOSTS_TABLE_NAME = "hosts"
|
||||
// ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys
|
||||
ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys"
|
||||
|
||||
// == ERROR CONSTS ==
|
||||
// NO_RECORD - no singular result found
|
||||
|
@ -138,6 +140,7 @@ func createTables() {
|
|||
createTable(USER_GROUPS_TABLE_NAME)
|
||||
createTable(CACHE_TABLE_NAME)
|
||||
createTable(HOSTS_TABLE_NAME)
|
||||
createTable(ENROLLMENT_KEYS_TABLE_NAME)
|
||||
}
|
||||
|
||||
func createTable(tableName string) error {
|
||||
|
|
|
@ -5,7 +5,6 @@ import "testing"
|
|||
func Test_genKeyName(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
kname := genKeyName()
|
||||
t.Log(kname)
|
||||
if len(kname) != 20 {
|
||||
t.Fatalf("improper length of key name, expected 20 got :%d", len(kname))
|
||||
}
|
||||
|
@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) {
|
|||
func Test_genKey(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
kname := GenKey()
|
||||
t.Log(kname)
|
||||
if len(kname) != 16 {
|
||||
t.Fatalf("improper length of key name, expected 16 got :%d", len(kname))
|
||||
}
|
||||
|
|
221
logic/enrollmentkey.go
Normal file
221
logic/enrollmentkey.go
Normal file
|
@ -0,0 +1,221 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
)
|
||||
|
||||
// EnrollmentErrors - struct for holding EnrollmentKey error messages
|
||||
var EnrollmentErrors = struct {
|
||||
InvalidCreate error
|
||||
NoKeyFound error
|
||||
InvalidKey error
|
||||
NoUsesRemaining error
|
||||
FailedToTokenize error
|
||||
FailedToDeTokenize error
|
||||
}{
|
||||
InvalidCreate: fmt.Errorf("invalid enrollment key created"),
|
||||
NoKeyFound: fmt.Errorf("no enrollmentkey found"),
|
||||
InvalidKey: fmt.Errorf("invalid key provided"),
|
||||
NoUsesRemaining: fmt.Errorf("no uses remaining"),
|
||||
FailedToTokenize: fmt.Errorf("failed to tokenize"),
|
||||
FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
|
||||
}
|
||||
|
||||
// CreateEnrollmentKey - creates a new enrollment key in db
|
||||
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
|
||||
newKeyID, err := getUniqueEnrollmentID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k = &models.EnrollmentKey{
|
||||
Value: newKeyID,
|
||||
Expiration: time.Time{},
|
||||
UsesRemaining: 0,
|
||||
Unlimited: unlimited,
|
||||
Networks: []string{},
|
||||
Tags: []string{},
|
||||
}
|
||||
if uses > 0 {
|
||||
k.UsesRemaining = uses
|
||||
}
|
||||
if !expiration.IsZero() {
|
||||
k.Expiration = expiration
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
k.Networks = networks
|
||||
}
|
||||
if len(tags) > 0 {
|
||||
k.Tags = tags
|
||||
}
|
||||
if ok := k.Validate(); !ok {
|
||||
return nil, EnrollmentErrors.InvalidCreate
|
||||
}
|
||||
if err = upsertEnrollmentKey(k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetAllEnrollmentKeys - fetches all enrollment keys from DB
|
||||
func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
|
||||
currentKeys, err := getEnrollmentKeysMap()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var currentKeysList = []*models.EnrollmentKey{}
|
||||
for k := range currentKeys {
|
||||
currentKeysList = append(currentKeysList, currentKeys[k])
|
||||
}
|
||||
return currentKeysList, nil
|
||||
}
|
||||
|
||||
// GetEnrollmentKey - fetches a single enrollment key
|
||||
// returns nil and error if not found
|
||||
func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
|
||||
currentKeys, err := getEnrollmentKeysMap()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key, ok := currentKeys[value]; ok {
|
||||
return key, nil
|
||||
}
|
||||
return nil, EnrollmentErrors.NoKeyFound
|
||||
}
|
||||
|
||||
// DeleteEnrollmentKey - delete's a given enrollment key by value
|
||||
func DeleteEnrollmentKey(value string) error {
|
||||
_, err := GetEnrollmentKey(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
|
||||
}
|
||||
|
||||
// TryToUseEnrollmentKey - checks first if key can be decremented
|
||||
// returns true if it is decremented or isvalid
|
||||
func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
|
||||
key, err := decrementEnrollmentKey(k.Value)
|
||||
if err != nil {
|
||||
if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
|
||||
return k.IsValid()
|
||||
}
|
||||
} else {
|
||||
k.UsesRemaining = key.UsesRemaining
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Tokenize - tokenizes an enrollment key to be used via registration
|
||||
// and attaches it to the Token field on the struct
|
||||
func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
|
||||
if len(serverAddr) == 0 || k == nil {
|
||||
return EnrollmentErrors.FailedToTokenize
|
||||
}
|
||||
newToken := models.EnrollmentToken{
|
||||
Server: serverAddr,
|
||||
Value: k.Value,
|
||||
}
|
||||
data, err := json.Marshal(&newToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.Token = b64.StdEncoding.EncodeToString(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeTokenize - detokenizes a base64 encoded string
|
||||
// and finds the associated enrollment key
|
||||
func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
|
||||
if len(b64Token) == 0 {
|
||||
return nil, EnrollmentErrors.FailedToDeTokenize
|
||||
}
|
||||
tokenData, err := b64.StdEncoding.DecodeString(b64Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newToken models.EnrollmentToken
|
||||
err = json.Unmarshal(tokenData, &newToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k, err := GetEnrollmentKey(newToken.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// == private ==
|
||||
|
||||
// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
|
||||
func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
|
||||
k, err := GetEnrollmentKey(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if k.UsesRemaining == 0 {
|
||||
return nil, EnrollmentErrors.NoUsesRemaining
|
||||
}
|
||||
k.UsesRemaining = k.UsesRemaining - 1
|
||||
if err = upsertEnrollmentKey(k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return k, nil
|
||||
}
|
||||
|
||||
func upsertEnrollmentKey(k *models.EnrollmentKey) error {
|
||||
if k == nil {
|
||||
return EnrollmentErrors.InvalidKey
|
||||
}
|
||||
data, err := json.Marshal(k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
|
||||
}
|
||||
|
||||
func getUniqueEnrollmentID() (string, error) {
|
||||
currentKeys, err := getEnrollmentKeysMap()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newID := ncutils.MakeRandomString(models.EnrollmentKeyLength)
|
||||
for _, ok := currentKeys[newID]; ok; {
|
||||
newID = ncutils.MakeRandomString(models.EnrollmentKeyLength)
|
||||
}
|
||||
return newID, nil
|
||||
}
|
||||
|
||||
func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
|
||||
records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
|
||||
if err != nil {
|
||||
if !database.IsEmptyRecord(err) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if records == nil {
|
||||
records = make(map[string]string)
|
||||
}
|
||||
currentKeys := make(map[string]*models.EnrollmentKey, 0)
|
||||
if len(records) > 0 {
|
||||
for k := range records {
|
||||
var currentKey models.EnrollmentKey
|
||||
if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil {
|
||||
continue
|
||||
}
|
||||
currentKeys[k] = ¤tKey
|
||||
}
|
||||
}
|
||||
return currentKeys, nil
|
||||
}
|
206
logic/enrollmentkey_test.go
Normal file
206
logic/enrollmentkey_test.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateEnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
t.Run("Can_Not_Create_Key", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
|
||||
assert.Nil(t, newKey)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
|
||||
})
|
||||
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, newKey.UsesRemaining)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Time", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Networks) == 2)
|
||||
})
|
||||
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Tags) == 2)
|
||||
})
|
||||
|
||||
t.Run("Can_Get_List_of_Keys", func(t *testing.T) {
|
||||
keys, err := GetAllEnrollmentKeys()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, len(keys) > 0)
|
||||
for i := range keys {
|
||||
assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength)
|
||||
}
|
||||
})
|
||||
removeAllEnrollments()
|
||||
}
|
||||
|
||||
func TestDelete_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
t.Run("Can_Delete_Key", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
err := DeleteEnrollmentKey(newKey.Value)
|
||||
assert.Nil(t, err)
|
||||
oldKey, err := GetEnrollmentKey(newKey.Value)
|
||||
assert.Nil(t, oldKey)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
|
||||
})
|
||||
t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
|
||||
err := DeleteEnrollmentKey("notakey")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
|
||||
})
|
||||
removeAllEnrollments()
|
||||
}
|
||||
|
||||
func TestDecrement_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
t.Run("Check_initial_uses", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.Equal(t, newKey.UsesRemaining, 1)
|
||||
})
|
||||
t.Run("Check can decrement", func(t *testing.T) {
|
||||
assert.Equal(t, newKey.UsesRemaining, 1)
|
||||
k, err := decrementEnrollmentKey(newKey.Value)
|
||||
assert.Nil(t, err)
|
||||
newKey = k
|
||||
})
|
||||
t.Run("Check can not decrement", func(t *testing.T) {
|
||||
assert.Equal(t, newKey.UsesRemaining, 0)
|
||||
_, err := decrementEnrollmentKey(newKey.Value)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining)
|
||||
})
|
||||
|
||||
removeAllEnrollments()
|
||||
}
|
||||
|
||||
func TestUsability_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
|
||||
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
|
||||
t.Run("Check if valid use key can be used", func(t *testing.T) {
|
||||
assert.Equal(t, key1.UsesRemaining, 1)
|
||||
ok := TryToUseEnrollmentKey(key1)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 0, key1.UsesRemaining)
|
||||
})
|
||||
|
||||
t.Run("Check if valid time key can be used", func(t *testing.T) {
|
||||
assert.True(t, !key2.Expiration.IsZero())
|
||||
ok := TryToUseEnrollmentKey(key2)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
|
||||
assert.True(t, key3.Unlimited)
|
||||
ok := TryToUseEnrollmentKey(key3)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("check invalid key can not be used", func(t *testing.T) {
|
||||
ok := TryToUseEnrollmentKey(key1)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func removeAllEnrollments() {
|
||||
database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
|
||||
}
|
||||
|
||||
//Test that cheks if it can tokenize
|
||||
//Test that cheks if it can't tokenize
|
||||
|
||||
func TestTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
|
||||
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) {
|
||||
err := Tokenize(nil, "ServerAddress")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
|
||||
})
|
||||
t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) {
|
||||
err := Tokenize(newKey, "")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
|
||||
})
|
||||
|
||||
t.Run("Can_Tokenize", func(t *testing.T) {
|
||||
err := Tokenize(newKey, serverAddr)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, len(newKey.Token) > 0)
|
||||
})
|
||||
|
||||
t.Run("Is_Correct_B64_Token", func(t *testing.T) {
|
||||
newKey.Value = defaultValue
|
||||
err := Tokenize(newKey, serverAddr)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, newKey.Token, b64value)
|
||||
})
|
||||
removeAllEnrollments()
|
||||
}
|
||||
|
||||
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
|
||||
t.Run("Can_Not_DeTokenize", func(t *testing.T) {
|
||||
value, err := DeTokenize("")
|
||||
assert.Nil(t, value)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize)
|
||||
})
|
||||
t.Run("Can_Not_Find_Key", func(t *testing.T) {
|
||||
value, err := DeTokenize(b64Value)
|
||||
assert.Nil(t, value)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
|
||||
})
|
||||
t.Run("Can_DeTokenize", func(t *testing.T) {
|
||||
err := Tokenize(newKey, serverAddr)
|
||||
assert.Nil(t, err)
|
||||
output, err := DeTokenize(newKey.Token)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, output)
|
||||
assert.Equal(t, newKey.Value, output.Value)
|
||||
})
|
||||
|
||||
removeAllEnrollments()
|
||||
}
|
|
@ -2,37 +2,28 @@ package logic
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/matryer/is"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
func TestCheckPorts(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
CreateAdmin(&models.User{
|
||||
UserName: "admin",
|
||||
Password: "password",
|
||||
IsAdmin: true,
|
||||
Networks: []string{},
|
||||
Groups: []string{},
|
||||
})
|
||||
peerUpdate := make(chan *models.Node)
|
||||
go ManageZombies(context.Background(), peerUpdate)
|
||||
go func() {
|
||||
for update := range peerUpdate {
|
||||
for y := range peerUpdate {
|
||||
fmt.Printf("Pointless %v\n", y)
|
||||
//do nothing
|
||||
logger.Log(3, "received node update", update.Action)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestCheckPorts(t *testing.T) {
|
||||
h := models.Host{
|
||||
ID: uuid.New(),
|
||||
EndpointIP: net.ParseIP("192.168.1.1"),
|
||||
|
|
38
logic/hostactions/hostactions.go
Normal file
38
logic/hostactions/hostactions.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package hostactions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
// nodeActionHandler - handles the storage of host action updates
|
||||
var nodeActionHandler sync.Map
|
||||
|
||||
// AddAction - adds a host action to a host's list to be retrieved from broker update
|
||||
func AddAction(hu models.HostUpdate) {
|
||||
currentRecords, ok := nodeActionHandler.Load(hu.Host.ID.String())
|
||||
if !ok { // no list exists yet
|
||||
nodeActionHandler.Store(hu.Host.ID.String(), []models.HostUpdate{hu})
|
||||
} else { // list exists, append to it
|
||||
currentList := currentRecords.([]models.HostUpdate)
|
||||
currentList = append(currentList, hu)
|
||||
nodeActionHandler.Store(hu.Host.ID.String(), currentList)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAction - gets an action if exists
|
||||
// TODO: may need to move to DB rather than sync map for HA
|
||||
func GetAction(id string) *models.HostUpdate {
|
||||
currentRecords, ok := nodeActionHandler.Load(id)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
currentList := currentRecords.([]models.HostUpdate)
|
||||
if len(currentList) > 0 {
|
||||
hu := currentList[0]
|
||||
nodeActionHandler.Store(hu.Host.ID.String(), currentList[1:])
|
||||
return &hu
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -90,7 +90,7 @@ func CreateHost(h *models.Host) error {
|
|||
if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
|
||||
return ErrHostExists
|
||||
}
|
||||
//encrypt that password so we never see it
|
||||
// encrypt that password so we never see it
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -236,7 +236,12 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Nodes = append(h.Nodes, n.ID.String())
|
||||
currentHost, err := GetHost(h.ID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.HostPass = currentHost.HostPass
|
||||
h.Nodes = append(currentHost.Nodes, n.ID.String())
|
||||
return UpsertHost(h)
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,6 @@ func GetProxyUpdateForHost(host *models.Host) (models.ProxyManagerPayload, error
|
|||
} else {
|
||||
logger.Log(0, "couldn't find relay host for: ", host.ID.String())
|
||||
}
|
||||
|
||||
}
|
||||
if host.IsRelay {
|
||||
relayedHosts := GetRelayedHosts(host)
|
||||
|
@ -142,9 +141,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
|
|||
if deletedNode != nil {
|
||||
deletedNodes = append(deletedNodes, *deletedNode)
|
||||
}
|
||||
logger.Log(1, "peer update for host ", host.ID.String())
|
||||
logger.Log(1, "peer update for host", host.ID.String())
|
||||
peerIndexMap := make(map[string]int)
|
||||
for _, nodeID := range host.Nodes {
|
||||
nodeID := nodeID
|
||||
node, err := GetNodeByID(nodeID)
|
||||
if err != nil {
|
||||
continue
|
||||
|
@ -163,7 +163,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
|
|||
}
|
||||
for _, peer := range currentPeers {
|
||||
peer := peer
|
||||
if peer.ID == node.ID {
|
||||
if peer.ID.String() == node.ID.String() {
|
||||
logger.Log(2, "peer update, skipping self")
|
||||
//skip yourself
|
||||
continue
|
||||
|
@ -185,7 +185,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
|
|||
continue
|
||||
}
|
||||
if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) {
|
||||
log.Println("peer update, skipping node for acl")
|
||||
logger.Log(2, "peer update, skipping node for acl")
|
||||
//skip if not permitted by acl
|
||||
continue
|
||||
}
|
||||
|
|
59
models/enrollment_key.go
Normal file
59
models/enrollment_key.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// EnrollmentToken - the tokenized version of an enrollmentkey;
|
||||
// to be used for host registration
|
||||
type EnrollmentToken struct {
|
||||
Server string `json:"server"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// EnrollmentKeyLength - the length of an enrollment key - 62^16 unique possibilities
|
||||
const EnrollmentKeyLength = 32
|
||||
|
||||
// EnrollmentKey - the key used to register hosts and join them to specific networks
|
||||
type EnrollmentKey struct {
|
||||
Expiration time.Time `json:"expiration"`
|
||||
UsesRemaining int `json:"uses_remaining"`
|
||||
Value string `json:"value"`
|
||||
Networks []string `json:"networks"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
Tags []string `json:"tags"`
|
||||
Token string `json:"token,omitempty"` // B64 value of EnrollmentToken
|
||||
}
|
||||
|
||||
// APIEnrollmentKey - used to create enrollment keys via API
|
||||
type APIEnrollmentKey struct {
|
||||
Expiration int64 `json:"expiration"`
|
||||
UsesRemaining int `json:"uses_remaining"`
|
||||
Networks []string `json:"networks"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
// EnrollmentKey.IsValid - checks if the key is still valid to use
|
||||
func (k *EnrollmentKey) IsValid() bool {
|
||||
if k == nil {
|
||||
return false
|
||||
}
|
||||
if k.UsesRemaining > 0 {
|
||||
return true
|
||||
}
|
||||
if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) {
|
||||
return true
|
||||
}
|
||||
|
||||
return k.Unlimited
|
||||
}
|
||||
|
||||
// EnrollmentKey.Validate - validate's an EnrollmentKey
|
||||
// should be used during creation
|
||||
func (k *EnrollmentKey) Validate() bool {
|
||||
return k.Networks != nil &&
|
||||
k.Tags != nil &&
|
||||
len(k.Value) == EnrollmentKeyLength &&
|
||||
k.IsValid()
|
||||
}
|
|
@ -74,6 +74,10 @@ const (
|
|||
DeleteHost = "DELETE_HOST"
|
||||
// JoinHostToNetwork - constant for host network join action
|
||||
JoinHostToNetwork = "JOIN_HOST_TO_NETWORK"
|
||||
// Acknowledgement - ACK response for hosts
|
||||
Acknowledgement = "ACK"
|
||||
// RequestAck - request an ACK
|
||||
RequestAck = "REQ_ACK"
|
||||
)
|
||||
|
||||
// HostUpdate - struct for host update
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/hostactions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
|
@ -144,6 +145,19 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
|
|||
logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String()))
|
||||
var sendPeerUpdate bool
|
||||
switch hostUpdate.Action {
|
||||
case models.Acknowledgement:
|
||||
hu := hostactions.GetAction(currentHost.ID.String())
|
||||
if hu != nil {
|
||||
if err = HostUpdate(hu); err != nil {
|
||||
logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
|
||||
return
|
||||
} else {
|
||||
if err = PublishSingleHostPeerUpdate(currentHost, nil); err != nil {
|
||||
logger.Log(0, "failed peers publish after join acknowledged", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
case models.UpdateHost:
|
||||
sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost)
|
||||
err := logic.UpsertHost(currentHost)
|
||||
|
@ -169,6 +183,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
|
|||
}
|
||||
sendPeerUpdate = true
|
||||
}
|
||||
|
||||
if sendPeerUpdate {
|
||||
err := PublishPeerUpdate()
|
||||
if err != nil {
|
||||
|
|
|
@ -61,6 +61,9 @@ func PublishSingleHostPeerUpdate(host *models.Host, deletedNode *models.Node) er
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(peerUpdate.Peers) == 0 { // no peers to send
|
||||
return nil
|
||||
}
|
||||
if host.ProxyEnabled {
|
||||
proxyUpdate, err := logic.GetProxyUpdateForHost(host)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue