From e40712063c015a97fd8f752ccd0d7d178237a04e Mon Sep 17 00:00:00 2001 From: Matthew R Kasun Date: Sun, 25 Apr 2021 08:18:43 -0400 Subject: [PATCH] add validation for node.Address --- controllers/common.go | 7 +- controllers/nodeHttpController.go | 1 - models/node.go | 195 ++++++++++++------------ test/api_test.go | 55 ++----- test/{group_test.go => network_test.go} | 40 +++-- test/node_test.go | 19 ++- test/user_test.go | 6 - 7 files changed, 155 insertions(+), 168 deletions(-) rename test/{group_test.go => network_test.go} (96%) diff --git a/controllers/common.go b/controllers/common.go index 0be76e30..3cdb70db 100644 --- a/controllers/common.go +++ b/controllers/common.go @@ -62,6 +62,11 @@ func ValidateNode(operation string, networkName string, node models.Node) error v := validator.New() + _ = v.RegisterValidation("address_check", func(fl validator.FieldLevel) bool { + isIpv4 := functions.IsIpv4Net(node.Address) + notEmptyCheck := node.Address != "" + return (notEmptyCheck && isIpv4) + }) _ = v.RegisterValidation("endpoint_check", func(fl validator.FieldLevel) bool { //var isFieldUnique bool = functions.IsFieldUnique(networkName, "endpoint", node.Endpoint) isIpv4 := functions.IsIpv4Net(node.Endpoint) @@ -193,6 +198,7 @@ func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) { // prepare update model. update := bson.D{ {"$set", bson.D{ + {"address", node.Address}, {"name", node.Name}, {"password", node.Password}, {"listenport", node.ListenPort}, @@ -212,7 +218,6 @@ func UpdateNode(nodechange models.Node, node models.Node) (models.Node, error) { }}, } var nodeupdate models.Node - errN := collection.FindOneAndUpdate(ctx, filter, update).Decode(&nodeupdate) if errN != nil { return nodeupdate, errN diff --git a/controllers/nodeHttpController.go b/controllers/nodeHttpController.go index 41407ddd..7dfcaddc 100644 --- a/controllers/nodeHttpController.go +++ b/controllers/nodeHttpController.go @@ -765,7 +765,6 @@ func updateNode(w http.ResponseWriter, r *http.Request) { if nodechange.MacAddress == "" { nodechange.MacAddress = node.MacAddress } - err = ValidateNode("update", params["network"], nodechange) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) diff --git a/models/node.go b/models/node.go index 99dce971..8b435e8e 100644 --- a/models/node.go +++ b/models/node.go @@ -1,152 +1,149 @@ package models import ( - "go.mongodb.org/mongo-driver/bson/primitive" - "github.com/gravitl/netmaker/mongoconn" - "math/rand" - "time" - "net" - "context" - "go.mongodb.org/mongo-driver/bson" + "context" + "math/rand" + "net" + "time" + + "github.com/gravitl/netmaker/mongoconn" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" ) const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" var seededRand *rand.Rand = rand.New( - rand.NewSource(time.Now().UnixNano())) + rand.NewSource(time.Now().UnixNano())) //node struct type Node struct { - ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"` - Address string `json:"address" bson:"address"` - LocalAddress string `json:"localaddress" bson:"localaddress" validate:"localaddress_check"` - Name string `json:"name" bson:"name" validate:"omitempty,name_valid,max=12"` - ListenPort int32 `json:"listenport" bson:"listenport" validate:"omitempty,numeric,min=1024,max=65535"` - PublicKey string `json:"publickey" bson:"publickey" validate:"pubkey_check"` - Endpoint string `json:"endpoint" bson:"endpoint" validate:"endpoint_check"` - PostUp string `json:"postup" bson:"postup"` - PostDown string `json:"postdown" bson:"postdown"` - AllowedIPs string `json:"allowedips" bson:"allowedips"` - PersistentKeepalive int32 `json:"persistentkeepalive" bson:"persistentkeepalive" validate: "omitempty,numeric,max=1000"` - SaveConfig *bool `json:"saveconfig" bson:"saveconfig"` - AccessKey string `json:"accesskey" bson:"accesskey"` - Interface string `json:"interface" bson:"interface"` - LastModified int64 `json:"lastmodified" bson:"lastmodified"` - KeyUpdateTimeStamp int64 `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"` - ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime"` - LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate"` - LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin"` - MacAddress string `json:"macaddress" bson:"macaddress" validate:"required,macaddress_valid,macaddress_unique"` - CheckInInterval int32 `json:"checkininterval" bson:"checkininterval"` - Password string `json:"password" bson:"password" validate:"password_check"` - Network string `json:"network" bson:"network" validate:"network_exists"` - IsPending bool `json:"ispending" bson:"ispending"` - IsGateway bool `json:"isgateway" bson:"isgateway"` - GatewayRange string `json:"gatewayrange" bson:"gatewayrange"` - PostChanges string `json:"postchanges" bson:"postchanges"` + ID primitive.ObjectID `json:"_id,omitempty" bson:"_id,omitempty"` + Address string `json:"address" bson:"address" validate:"address_check"` + LocalAddress string `json:"localaddress" bson:"localaddress" validate:"localaddress_check"` + Name string `json:"name" bson:"name" validate:"omitempty,name_valid,max=12"` + ListenPort int32 `json:"listenport" bson:"listenport" validate:"omitempty,numeric,min=1024,max=65535"` + PublicKey string `json:"publickey" bson:"publickey" validate:"pubkey_check"` + Endpoint string `json:"endpoint" bson:"endpoint" validate:"endpoint_check"` + PostUp string `json:"postup" bson:"postup"` + PostDown string `json:"postdown" bson:"postdown"` + AllowedIPs string `json:"allowedips" bson:"allowedips"` + PersistentKeepalive int32 `json:"persistentkeepalive" bson:"persistentkeepalive" validate: "omitempty,numeric,max=1000"` + SaveConfig *bool `json:"saveconfig" bson:"saveconfig"` + AccessKey string `json:"accesskey" bson:"accesskey"` + Interface string `json:"interface" bson:"interface"` + LastModified int64 `json:"lastmodified" bson:"lastmodified"` + KeyUpdateTimeStamp int64 `json:"keyupdatetimestamp" bson:"keyupdatetimestamp"` + ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime"` + LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate"` + LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin"` + MacAddress string `json:"macaddress" bson:"macaddress" validate:"required,macaddress_valid,macaddress_unique"` + CheckInInterval int32 `json:"checkininterval" bson:"checkininterval"` + Password string `json:"password" bson:"password" validate:"password_check"` + Network string `json:"network" bson:"network" validate:"network_exists"` + IsPending bool `json:"ispending" bson:"ispending"` + IsGateway bool `json:"isgateway" bson:"isgateway"` + GatewayRange string `json:"gatewayrange" bson:"gatewayrange"` + PostChanges string `json:"postchanges" bson:"postchanges"` } - //TODO: Contains a fatal error return. Need to change //Used in contexts where it's not the Parent network. -func(node *Node) GetNetwork() (Network, error){ +func (node *Node) GetNetwork() (Network, error) { - var network Network + var network Network - collection := mongoconn.NetworkDB - //collection := mongoconn.Client.Database("netmaker").Collection("networks") + collection := mongoconn.NetworkDB + //collection := mongoconn.Client.Database("netmaker").Collection("networks") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - filter := bson.M{"netid": node.Network} - err := collection.FindOne(ctx, filter).Decode(&network) + filter := bson.M{"netid": node.Network} + err := collection.FindOne(ctx, filter).Decode(&network) - defer cancel() + defer cancel() - if err != nil { - //log.Fatal(err) - return network, err - } + if err != nil { + //log.Fatal(err) + return network, err + } - return network, err + return network, err } - //TODO: //Not sure if below two methods are necessary. May want to revisit -func(node *Node) SetLastModified(){ +func (node *Node) SetLastModified() { node.LastModified = time.Now().Unix() } -func(node *Node) SetLastCheckIn(){ - node.LastCheckIn = time.Now().Unix() +func (node *Node) SetLastCheckIn() { + node.LastCheckIn = time.Now().Unix() } -func(node *Node) SetLastPeerUpdate(){ - node.LastPeerUpdate = time.Now().Unix() +func (node *Node) SetLastPeerUpdate() { + node.LastPeerUpdate = time.Now().Unix() } -func(node *Node) SetExpirationDateTime(){ - node.ExpirationDateTime = time.Unix(33174902665, 0).Unix() +func (node *Node) SetExpirationDateTime() { + node.ExpirationDateTime = time.Unix(33174902665, 0).Unix() } - -func(node *Node) SetDefaultName(){ - if node.Name == "" { - nodeid := StringWithCharset(5, charset) - nodename := "node-" + nodeid - node.Name = nodename - } +func (node *Node) SetDefaultName() { + if node.Name == "" { + nodeid := StringWithCharset(5, charset) + nodename := "node-" + nodeid + node.Name = nodename + } } //TODO: I dont know why this exists //This should exist on the node.go struct. I'm sure there was a reason? -func(node *Node) SetDefaults() { +func (node *Node) SetDefaults() { - //TODO: Maybe I should make Network a part of the node struct. Then we can just query the Network object for stuff. - parentNetwork, _ := node.GetNetwork() + //TODO: Maybe I should make Network a part of the node struct. Then we can just query the Network object for stuff. + parentNetwork, _ := node.GetNetwork() - node.ExpirationDateTime = time.Unix(33174902665, 0).Unix() + node.ExpirationDateTime = time.Unix(33174902665, 0).Unix() - if node.ListenPort == 0 { - node.ListenPort = parentNetwork.DefaultListenPort - } - if node.PostDown == "" { - //Empty because we dont set it - //may want to set it to something in the future - } - //TODO: This is dumb and doesn't work - //Need to change - if node.SaveConfig == nil { - defaultsave := *parentNetwork.DefaultSaveConfig - node.SaveConfig = &defaultsave - } - if node.Interface == "" { - node.Interface = parentNetwork.DefaultInterface - } - if node.PersistentKeepalive == 0 { - node.PersistentKeepalive = parentNetwork.DefaultKeepalive - } - if node.PostUp == "" { - postup := parentNetwork.DefaultPostUp - node.PostUp = postup - } - node.CheckInInterval = parentNetwork.DefaultCheckInInterval + if node.ListenPort == 0 { + node.ListenPort = parentNetwork.DefaultListenPort + } + if node.PostDown == "" { + //Empty because we dont set it + //may want to set it to something in the future + } + //TODO: This is dumb and doesn't work + //Need to change + if node.SaveConfig == nil { + defaultsave := *parentNetwork.DefaultSaveConfig + node.SaveConfig = &defaultsave + } + if node.Interface == "" { + node.Interface = parentNetwork.DefaultInterface + } + if node.PersistentKeepalive == 0 { + node.PersistentKeepalive = parentNetwork.DefaultKeepalive + } + if node.PostUp == "" { + postup := parentNetwork.DefaultPostUp + node.PostUp = postup + } + node.CheckInInterval = parentNetwork.DefaultCheckInInterval } func StringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) } //Check for valid IPv4 address //Note: We dont handle IPv6 AT ALL!!!!! This definitely is needed at some point //But for iteration 1, lets just stick to IPv4. Keep it simple stupid. func IsIpv4Net(host string) bool { - return net.ParseIP(host) != nil + return net.ParseIP(host) != nil } - diff --git a/test/api_test.go b/test/api_test.go index eee851ec..647db459 100644 --- a/test/api_test.go +++ b/test/api_test.go @@ -17,11 +17,6 @@ import ( "github.com/stretchr/testify/assert" ) -type databaseError struct { - Inner *int - Errors int -} - //should be use models.SuccessResponse and models.SuccessfulUserLoginResponse //rather than creating new type but having trouble decoding that way type Auth struct { @@ -96,7 +91,6 @@ func api(t *testing.T, data interface{}, method, url, authorization string) (*ht request.Header.Set("authorization", "Bearer "+authorization) } client := http.Client{} - //t.Log("api request", request) return client.Do(request) } @@ -198,28 +192,6 @@ func deleteKey(t *testing.T, key, network string) { //assert.Equal(t, int64(1), message.DeletedCount) } -func networkExists(t *testing.T) bool { - response, err := api(t, "", http.MethodGet, baseURL+"/api/networks", "secretkey") - assert.Nil(t, err, err) - defer response.Body.Close() - assert.Equal(t, http.StatusOK, response.StatusCode) - err = json.NewDecoder(response.Body).Decode(&Networks) - assert.Nil(t, err, err) - for i, network := range Networks { - t.Log(i, network) - if network.NetID == "" { - return false - } else { - return true - } - } - return false -} - -func TestJunk(t *testing.T) { - deleteNetworks(t) -} - func deleteNetworks(t *testing.T) { //delete all node deleteAllNodes(t) @@ -237,21 +209,6 @@ func deleteNetworks(t *testing.T) { } } -func getNetworkNodes(t *testing.T) []models.ReturnNode { - var nodes []models.ReturnNode - //var node models.ReturnNode - //response, err := api(t, "", http.MethodGet, baseURL+"/api/nodes/skynet", "secretkey") - //assert.Nil(t, err, err) - //assert.Equal(t, http.StatusOK, response.StatusCode) - //defer response.Body.Close() - //err = json.NewDecoder(response.Body).Decode(&nodes) - //assert.Nil(t, err, err) - //for _, nodes := range nodes { - // nodes = append(nodes, node) - //} - return nodes -} - func deleteNode(t *testing.T) { response, err := api(t, "", http.MethodDelete, baseURL+"/api/nodes/skynet/01:02:03:04:05:06", "secretkey") assert.Nil(t, err, err) @@ -273,6 +230,7 @@ func deleteAllNodes(t *testing.T) { func createNode(t *testing.T) { var node models.Node key := createAccessKey(t) + node.Address = "10.71.0.1" node.AccessKey = key.Value node.MacAddress = "01:02:03:04:05:06" node.Name = "myNode" @@ -296,6 +254,17 @@ func getNode(t *testing.T) models.Node { return node } +func getNetwork(t *testing.T, network string) models.Network { + var net models.Network + response, err := api(t, "", http.MethodGet, baseURL+"/api/networks/"+network, "secretkey") + assert.Nil(t, err, err) + assert.Equal(t, http.StatusOK, response.StatusCode) + defer response.Body.Close() + err = json.NewDecoder(response.Body).Decode(&net) + assert.Nil(t, err, err) + return net +} + func setup(t *testing.T) { deleteNetworks(t) createNetwork(t) diff --git a/test/group_test.go b/test/network_test.go similarity index 96% rename from test/group_test.go rename to test/network_test.go index 5907894a..bcb20921 100644 --- a/test/group_test.go +++ b/test/network_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/http" "testing" + "time" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" @@ -15,9 +16,7 @@ func TestCreateNetwork(t *testing.T) { network := models.Network{} network.NetID = "skynet" network.AddressRange = "10.71.0.0/16" - if networkExists(t) { - deleteNetworks(t) - } + deleteNetworks(t) t.Run("InvalidToken", func(t *testing.T) { response, err := api(t, network, http.MethodPost, baseURL+"/api/networks", "badkey") assert.Nil(t, err, err) @@ -140,8 +139,6 @@ func TestDeleteNetwork(t *testing.T) { }) t.Run("NodesExist", func(t *testing.T) { setup(t) - node := getNode(t) - t.Log(node) response, err := api(t, "", http.MethodDelete, baseURL+"/api/networks/skynet", "secretkey") assert.Nil(t, err, err) assert.Equal(t, http.StatusForbidden, response.StatusCode) @@ -523,7 +520,7 @@ func TestUpdateNetwork(t *testing.T) { DefaultKeepAlive int32 } var network Network - network.DefaultKeepAlive = 1001 + network.DefaultKeepAlive = 2000 response, err := api(t, network, http.MethodPut, baseURL+"/api/networks/skynet", "secretkey") assert.Nil(t, err, err) var message models.ErrorResponse @@ -534,21 +531,26 @@ func TestUpdateNetwork(t *testing.T) { assert.Equal(t, http.StatusBadRequest, response.StatusCode) }) t.Run("UpdateSaveConfig", func(t *testing.T) { - t.Skip() - //does not appear to be updatable + //t.Skip() + //not updatable, ensure attempt to change does not result in change type Network struct { DefaultSaveConfig *bool } var network Network - value := false + var value bool + oldnet := getNetwork(t, "skynet") + if *oldnet.DefaultSaveConfig == true { + value = false + } else { + value = true + } + network.DefaultSaveConfig = &value response, err := api(t, network, http.MethodPut, baseURL+"/api/networks/skynet", "secretkey") assert.Nil(t, err, err) assert.Equal(t, http.StatusOK, response.StatusCode) - defer response.Body.Close() - err = json.NewDecoder(response.Body).Decode(&returnedNetwork) - assert.Nil(t, err, err) - assert.Equal(t, *network.DefaultSaveConfig, *returnedNetwork.DefaultSaveConfig) + newnet := getNetwork(t, "skynet") + assert.Equal(t, oldnet.DefaultSaveConfig, newnet.DefaultSaveConfig) }) t.Run("UpdateManualSignUP", func(t *testing.T) { type Network struct { @@ -612,3 +614,15 @@ func TestUpdateNetwork(t *testing.T) { assert.Equal(t, network.DefaultListenPort, returnedNetwork.DefaultListenPort) }) } + +func TestKeyUpdate(t *testing.T) { + //get current network settings + oldnet := getNetwork(t, "skynet") + //update key + time.Sleep(time.Second * 1) + reply, err := api(t, "", http.MethodPost, baseURL+"/api/networks/skynet/keyupdate", "secretkey") + assert.Nil(t, err, err) + assert.Equal(t, http.StatusOK, reply.StatusCode) + newnet := getNetwork(t, "skynet") + assert.Greater(t, newnet.KeyUpdateTimeStamp, oldnet.KeyUpdateTimeStamp) +} diff --git a/test/node_test.go b/test/node_test.go index 0b596baf..f0d33ddb 100644 --- a/test/node_test.go +++ b/test/node_test.go @@ -104,6 +104,7 @@ func TestUpdateNode(t *testing.T) { setup(t) t.Run("UpdateMulti", func(t *testing.T) { + data.Address = "10.1.0.2" data.MacAddress = "01:02:03:04:05:05" data.Name = "NewName" data.PublicKey = "DM5qhLAE20PG9BbfBCgfr+Ac9D2NDOwCtY1rbYDLf34=" @@ -115,7 +116,6 @@ func TestUpdateNode(t *testing.T) { assert.Equal(t, http.StatusOK, response.StatusCode) defer response.Body.Close() var node models.Node - t.Log(response.Body) err = json.NewDecoder(response.Body).Decode(&node) assert.Nil(t, err, err) assert.Equal(t, data.Name, node.Name) @@ -125,6 +125,19 @@ func TestUpdateNode(t *testing.T) { assert.Equal(t, data.LocalAddress, node.LocalAddress) assert.Equal(t, data.Endpoint, node.Endpoint) }) + t.Run("InvalidAddress", func(t *testing.T) { + data.Address = "10.300.2.0" + response, err := api(t, data, http.MethodPut, baseURL+"/api/nodes/skynet/01:02:03:04:05:05", "secretkey") + assert.Nil(t, err, err) + assert.Equal(t, http.StatusBadRequest, response.StatusCode) + var message models.ErrorResponse + defer response.Body.Close() + err = json.NewDecoder(response.Body).Decode(&message) + assert.Nil(t, err, err) + assert.Equal(t, http.StatusBadRequest, message.Code) + assert.Contains(t, message.Message, "Field validation for 'Address' failed") + }) + t.Run("InvalidMacAddress", func(t *testing.T) { data.MacAddress = "10:11:12:13:14:15:16" response, err := api(t, data, http.MethodPut, baseURL+"/api/nodes/skynet/01:02:03:04:05:05", "secretkey") @@ -226,7 +239,6 @@ func TestDeleteNode(t *testing.T) { assert.Nil(t, err, err) assert.Equal(t, "01:02:03:04:05:06 deleted.", message.Message) assert.Equal(t, http.StatusOK, message.Code) - t.Log(response.Header.Get("Content-Type")) }) t.Run("NonExistantNode", func(t *testing.T) { response, err := api(t, "", http.MethodDelete, baseURL+"/api/nodes/skynet/01:02:03:04:05:06", "secretkey") @@ -335,7 +347,6 @@ func TestUncordonNode(t *testing.T) { err = json.NewDecoder(response.Body).Decode(&message) assert.Nil(t, err, err) assert.Equal(t, "SUCCESS", message) - t.Log(message, string(message)) } func TestCreateNode(t *testing.T) { @@ -525,9 +536,7 @@ func TestCreateNode(t *testing.T) { err = json.NewDecoder(response.Body).Decode(&message) assert.Nil(t, err, err) assert.Equal(t, node.Name, message.Name) - t.Log(message) }) - } func TestGetLastModified(t *testing.T) { diff --git a/test/user_test.go b/test/user_test.go index e3971c28..3b019b04 100644 --- a/test/user_test.go +++ b/test/user_test.go @@ -2,7 +2,6 @@ package main import ( "encoding/json" - "io/ioutil" "net/http" "testing" @@ -28,8 +27,6 @@ func TestAdminCreation(t *testing.T) { assert.Equal(t, true, user.IsAdmin) assert.Equal(t, http.StatusOK, response.StatusCode) assert.True(t, adminExists(t), "Admin creation failed") - message, _ := ioutil.ReadAll(response.Body) - t.Log(string(message)) }) t.Run("AdminCreationFailure", func(t *testing.T) { if !adminExists(t) { @@ -50,10 +47,7 @@ func TestAdminCreation(t *testing.T) { func TestGetUser(t *testing.T) { if !adminExists(t) { - t.Log("no admin - creating") addAdmin(t) - } else { - t.Log("admin exists") } t.Run("GetUserWithValidToken", func(t *testing.T) { token, err := authenticate(t)