mirror of
https://github.com/gravitl/netmaker.git
synced 2024-09-20 15:26:04 +08:00
feat(NET-688): auto relaying via enrollment keys (#2647)
* feat(NET-688): auto relaying via enrollment keys * feat(NET-688): address pr comments
This commit is contained in:
parent
75e110a5a6
commit
61ef6142ff
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// SessionHandler - called by the HTTP router when user
|
||||
|
@ -202,7 +203,7 @@ func SessionHandler(conn *websocket.Conn) {
|
|||
if err = conn.WriteMessage(messageType, reponseData); err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
}
|
||||
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host)
|
||||
go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
|
||||
case <-timeout: // the read from req.answerCh has timed out
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "error during timeout message writing:", err.Error())
|
||||
|
@ -221,7 +222,7 @@ func SessionHandler(conn *websocket.Conn) {
|
|||
}
|
||||
|
||||
// CheckNetRegAndHostUpdate - run through networks and send a host update
|
||||
func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
|
||||
func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) {
|
||||
// publish host update through MQ
|
||||
for i := range networks {
|
||||
network := networks[i]
|
||||
|
@ -231,6 +232,14 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
|
|||
logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
|
||||
continue
|
||||
}
|
||||
if relayNodeId != uuid.Nil && !newNode.IsRelayed {
|
||||
newNode.IsRelayed = true
|
||||
newNode.RelayedBy = relayNodeId.String()
|
||||
slog.Info(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
|
||||
if err := logic.UpsertNode(newNode); err != nil {
|
||||
slog.Error("failed to update node", "nodeid", relayNodeId.String())
|
||||
}
|
||||
}
|
||||
logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
|
||||
hostactions.AddAction(models.HostUpdate{
|
||||
Action: models.JoinHostToNetwork,
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
|
@ -26,6 +27,8 @@ func enrollmentKeyHandlers(r *mux.Router) {
|
|||
Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).
|
||||
Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey))).
|
||||
Methods(http.MethodPut)
|
||||
}
|
||||
|
||||
// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
|
||||
|
@ -113,12 +116,23 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
|||
newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
|
||||
}
|
||||
|
||||
relayId := uuid.Nil
|
||||
if enrollmentKeyBody.Relay != "" {
|
||||
relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "error parsing relay id: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newEnrollmentKey, err := logic.CreateEnrollmentKey(
|
||||
enrollmentKeyBody.UsesRemaining,
|
||||
newTime,
|
||||
enrollmentKeyBody.Networks,
|
||||
enrollmentKeyBody.Tags,
|
||||
enrollmentKeyBody.Unlimited,
|
||||
relayId,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
|
||||
|
@ -136,6 +150,57 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(newEnrollmentKey)
|
||||
}
|
||||
|
||||
// swagger:route PUT /api/v1/enrollment-keys/:id enrollmentKeys updateEnrollmentKey
|
||||
//
|
||||
// Updates an EnrollmentKey for hosts to use on Netmaker server. Updates only the relay to use.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
//
|
||||
// Responses:
|
||||
// 200: EnrollmentKey
|
||||
func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
var enrollmentKeyBody models.APIEnrollmentKey
|
||||
params := mux.Vars(r)
|
||||
keyId := params["keyID"]
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
|
||||
if err != nil {
|
||||
slog.Error("error decoding request body", "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
relayId := uuid.Nil
|
||||
if enrollmentKeyBody.Relay != "" {
|
||||
relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
|
||||
if err != nil {
|
||||
slog.Error("error parsing relay id", "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId)
|
||||
if err != nil {
|
||||
slog.Error("failed to update enrollment key", "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
|
||||
slog.Error("failed to update enrollment key", "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("updated enrollment key", "id", keyId)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(newEnrollmentKey)
|
||||
}
|
||||
|
||||
// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
|
||||
//
|
||||
// Handles a Netclient registration with server and add nodes accordingly.
|
||||
|
@ -286,5 +351,5 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(&response)
|
||||
// notify host of changes, peer and node updates
|
||||
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
|
||||
go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay)
|
||||
}
|
||||
|
|
|
@ -7,8 +7,10 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// EnrollmentErrors - struct for holding EnrollmentKey error messages
|
||||
|
@ -29,12 +31,12 @@ var EnrollmentErrors = struct {
|
|||
}
|
||||
|
||||
// CreateEnrollmentKey - creates a new enrollment key in db
|
||||
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
|
||||
func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
|
||||
newKeyID, err := getUniqueEnrollmentID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k = &models.EnrollmentKey{
|
||||
k := &models.EnrollmentKey{
|
||||
Value: newKeyID,
|
||||
Expiration: time.Time{},
|
||||
UsesRemaining: 0,
|
||||
|
@ -42,6 +44,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
|
|||
Networks: []string{},
|
||||
Tags: []string{},
|
||||
Type: models.Undefined,
|
||||
Relay: relay,
|
||||
}
|
||||
if uses > 0 {
|
||||
k.UsesRemaining = uses
|
||||
|
@ -61,10 +64,51 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
|
|||
if ok := k.Validate(); !ok {
|
||||
return nil, EnrollmentErrors.InvalidCreate
|
||||
}
|
||||
if relay != uuid.Nil {
|
||||
relayNode, err := GetNodeByID(relay.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !slices.Contains(k.Networks, relayNode.Network) {
|
||||
return nil, errors.New("relay node not in key's networks")
|
||||
}
|
||||
if !relayNode.IsRelay {
|
||||
return nil, errors.New("relay node is not a relay")
|
||||
}
|
||||
}
|
||||
if err = upsertEnrollmentKey(k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// UpdateEnrollmentKey - updates an existing enrollment key's associated relay
|
||||
func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) {
|
||||
key, err := GetEnrollmentKey(keyId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if relayId != uuid.Nil {
|
||||
relayNode, err := GetNodeByID(relayId.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !slices.Contains(key.Networks, relayNode.Network) {
|
||||
return nil, errors.New("relay node not in key's networks")
|
||||
}
|
||||
if !relayNode.IsRelay {
|
||||
return nil, errors.New("relay node is not a relay")
|
||||
}
|
||||
}
|
||||
|
||||
key.Relay = relayId
|
||||
|
||||
if err = upsertEnrollmentKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAllEnrollmentKeys - fetches all enrollment keys from DB
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -13,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
t.Run("Can_Not_Create_Key", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
|
||||
assert.Nil(t, newKey)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
|
||||
})
|
||||
t.Run("Can_Create_Key_Uses", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, newKey.UsesRemaining)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Time", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
})
|
||||
t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Networks) == 2)
|
||||
})
|
||||
t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
|
||||
newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.True(t, len(newKey.Tags) == 2)
|
||||
|
@ -61,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
|
|||
func TestDelete_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
|
||||
t.Run("Can_Delete_Key", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
err := DeleteEnrollmentKey(newKey.Value)
|
||||
|
@ -82,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
|
|||
func TestDecrement_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
|
||||
t.Run("Check_initial_uses", func(t *testing.T) {
|
||||
assert.True(t, newKey.IsValid())
|
||||
assert.Equal(t, newKey.UsesRemaining, 1)
|
||||
|
@ -106,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
|
|||
func TestUsability_EnrollmentKey(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
|
||||
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
|
||||
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
|
||||
key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
|
||||
key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil)
|
||||
key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
|
||||
t.Run("Check if valid use key can be used", func(t *testing.T) {
|
||||
assert.Equal(t, key1.UsesRemaining, 1)
|
||||
ok := TryToUseEnrollmentKey(key1)
|
||||
|
@ -144,7 +145,7 @@ func removeAllEnrollments() {
|
|||
func TestTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
|
||||
const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
|
||||
const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
|
@ -177,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
|
|||
func TestDeTokenize_EnrollmentKeys(t *testing.T) {
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
|
||||
newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
|
||||
const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
|
||||
const serverAddr = "api.myserver.com"
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ package models
|
|||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -39,6 +41,7 @@ type EnrollmentKey struct {
|
|||
Tags []string `json:"tags"`
|
||||
Token string `json:"token,omitempty"` // B64 value of EnrollmentToken
|
||||
Type KeyType `json:"type"`
|
||||
Relay uuid.UUID `json:"relay"`
|
||||
}
|
||||
|
||||
// APIEnrollmentKey - used to create enrollment keys via API
|
||||
|
@ -49,6 +52,7 @@ type APIEnrollmentKey struct {
|
|||
Unlimited bool `json:"unlimited"`
|
||||
Tags []string `json:"tags"`
|
||||
Type KeyType `json:"type"`
|
||||
Relay string `json:"relay"`
|
||||
}
|
||||
|
||||
// RegisterResponse - the response to a successful enrollment register
|
||||
|
|
|
@ -160,4 +160,5 @@ type RegisterMsg struct {
|
|||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
JoinAll bool `json:"join_all,omitempty"`
|
||||
Relay string `json:"relay,omitempty"`
|
||||
}
|
||||
|
|
|
@ -149,6 +149,7 @@ func RelayUpdates(currentNode, newNode *models.Node) bool {
|
|||
return relayUpdates
|
||||
}
|
||||
|
||||
// UpdateRelayed - updates a relay's relayed nodes, and sends updates to the relayed nodes over MQ
|
||||
func UpdateRelayed(currentNode, newNode *models.Node) {
|
||||
updatenodes := updateRelayNodes(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
|
||||
if len(updatenodes) > 0 {
|
||||
|
|
Loading…
Reference in a new issue