From 307a3d1e4b4e90a3c3e84d0f2e8f5031e458be2b Mon Sep 17 00:00:00 2001 From: Abhishek K Date: Wed, 21 May 2025 12:50:21 +0530 Subject: [PATCH] NET-1932: Merge egress and internet gateways (#3436) * feat: api access tokens * revoke all user tokens * redefine access token api routes, add auto egress option to enrollment keys * add server settings apis, add db table for settigs * handle server settings updates * switch to using settings from DB * fix sever settings migration * revet force migration for settings * fix server settings database write * egress model * fix revoked tokens to be unauthorized * update egress model * remove unused functions * convert access token to sql schema * switch access token to sql schema * fix merge conflicts * fix server settings types * bypass basic auth setting for super admin * add TODO comment * setup api handlers for egress revamp * use single DB, fix update nat boolean field * extend validaiton checks for egress ranges * add migration to convert to new egress model * fix panic interface conversion * publish peer update on settings update * revoke token generated by an user * add user token creation restriction by user role * add forbidden check for access token creation * revoke user token when group or role is changed * add default group to admin users on update * chore(go): import style changes from migration branch; 1. Singular file names for table schema. 2. No table name method. 3. Use .Model instead of .Table. 4. No unnecessary tagging. * remove nat check on egress gateway request * Revert "remove nat check on egress gateway request" This reverts commit 0aff12a189828fc4ccb4594adf7a3eb8772560f2. * remove nat check on egress gateway request * feat(go): add db middleware; * feat(go): restore method; * feat(go): add user access token schema; * add inet gw status to egress model * fetch node ids in the tag, add inet gw info clients * add inet gw info to node from egress list * add migration logic internet gws * create default acl policies * add egress info * add egress TODO * add egress TODO * fix user auth api: * add reference id to acl policy * add egress response from DB * publish peer update on egress changes * re initalise oauth and email config * set verbosity * normalise cidr on egress req * add egress id to acl group * change acls to use egress id * resolve merge conflicts * fix egress reference errors * move egress model to schema * add api context to DB * sync auto update settings with hosts * sync auto update settings with hosts * check acl for egress node * check for egress policy in the acl dst groups * fix acl rules for egress policies with new models * add status to egress model * fix inet node func * mask secret and convert jwt duration to minutes * enable egress policies on creation * convert jwt duration to minutes * add relevant ranges to inet egress * skip non active egress routes * resolve merge conflicts * fix static check * update gorm tag for primary key on egress model * create user policies for egress resources * resolve merge conflicts * get egress info on failover apis, add egress src validation for inet gws * add additional validation checks on egress req * add additional validation checks on egress req * skip all resources for inet policy * delete associated egress acl policies * fix failover of inetclient * avoid setting inet client asd inet gw * fix all resource egress policy * fix inet gw egress rule * check for node egress on relay req * fix egress acl rules comms * add new field for egress info on node * check acl policy in failover ctx * avoid default host to be set as inet client * fix relayed egress node * add valid error messaging for egress validate func * return if inet default host * jump port detection to 51821 * check host ports on pull * check user access gws via acls * add validation check for default host and failover for inet clients * add error messaging for acl policy check * fix inet gw status * ignore failover req for peer using inet gw * check for allowed egress ranges for a peer * add egress routes to static nodes by access * avoid setting failvoer as inet client * fix egress error messaging * fix extclients egress comms * fix inet gw acting as inet client * return formatted error on update acl validation * add default route for static nodes on inetclient * check relay node acting as inetclient * move inet node info to separate field, fix all resouces policy * remove debug logs --------- Co-authored-by: Vishal Dalwadi --- auth/host_session.go | 2 +- controllers/acls.go | 11 +- controllers/controller.go | 1 + controllers/egress.go | 257 +++++++++++++++++ controllers/enrollmentkeys.go | 2 +- controllers/ext_client.go | 9 +- controllers/hosts.go | 2 +- controllers/migrate.go | 2 +- controllers/node.go | 9 +- controllers/node_test.go | 94 ------- controllers/tags.go | 2 +- go.mod | 3 + go.sum | 16 ++ logic/acls.go | 514 ++++++++++++++++++++++++---------- logic/auth.go | 2 +- logic/egress.go | 366 ++++++++++++++++++++++++ logic/extpeers.go | 35 ++- logic/gateway.go | 42 ++- logic/hosts.go | 21 +- logic/jwts.go | 4 +- logic/networks.go | 2 - logic/nodes.go | 52 +++- logic/peers.go | 94 ++++--- logic/relay.go | 10 +- logic/tags.go | 2 +- logic/wireguard.go | 14 - migrate/migrate.go | 217 ++++++++++++++ models/accessToken.go | 60 ++++ models/acl.go | 1 + models/api_node.go | 21 +- models/egress.go | 14 + models/mqtt.go | 2 +- models/node.go | 26 +- models/structs.go | 1 + pro/controllers/failover.go | 34 ++- pro/controllers/inet_gws.go | 4 +- pro/controllers/users.go | 8 +- pro/initialize.go | 1 + pro/logic/failover.go | 5 +- pro/logic/nodes.go | 36 ++- pro/logic/user_mgmt.go | 22 +- schema/activity.go | 4 + schema/egress.go | 70 +++++ schema/models.go | 1 + 44 files changed, 1651 insertions(+), 444 deletions(-) create mode 100644 controllers/egress.go create mode 100644 logic/egress.go create mode 100644 models/accessToken.go create mode 100644 models/egress.go create mode 100644 schema/activity.go create mode 100644 schema/egress.go diff --git a/auth/host_session.go b/auth/host_session.go index d364ec5f..af5d69c7 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -165,7 +165,7 @@ func SessionHandler(conn *websocket.Conn) { return } } - logic.CheckHostPorts(&result.Host) + _ = logic.CheckHostPorts(&result.Host) if err := logic.CreateHost(&result.Host); err != nil { handleHostRegErr(conn, err) return diff --git a/controllers/acls.go b/controllers/acls.go index 9f958605..a89d20eb 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -51,7 +51,7 @@ func aclPolicyTypes(w http.ResponseWriter, r *http.Request) { DstGroupTypes: []models.AclGroupType{ models.NodeTagID, models.NodeID, - models.EgressRange, + models.EgressID, // models.NetmakerIPAclID, // models.NetmakerSubNetRangeAClID, }, @@ -171,6 +171,7 @@ func aclDebug(w http.ResponseWriter, r *http.Request) { IsPeerAllowed bool Policies []models.Acl IngressRules []models.FwRule + NodeAllPolicy bool } allowed, ps := logic.IsNodeAllowedToCommunicateV1(node, peer, true) @@ -253,8 +254,8 @@ func createAcl(w http.ResponseWriter, r *http.Request) { acl.Proto = models.ALL } // validate create acl policy - if !logic.IsAclPolicyValid(acl) { - logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest")) + if err := logic.IsAclPolicyValid(acl); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } err = logic.InsertAcl(acl) @@ -292,8 +293,8 @@ func updateAcl(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - if !logic.IsAclPolicyValid(updateAcl.Acl) { - logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest")) + if err := logic.IsAclPolicyValid(updateAcl.Acl); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } if updateAcl.Acl.NetworkID != acl.NetworkID { diff --git a/controllers/controller.go b/controllers/controller.go index 5c97505c..7fffd0bd 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -39,6 +39,7 @@ var HttpHandlers = []interface{}{ enrollmentKeyHandlers, tagHandlers, aclHandlers, + egressHandlers, legacyHandlers, } diff --git a/controllers/egress.go b/controllers/egress.go new file mode 100644 index 00000000..90e206dc --- /dev/null +++ b/controllers/egress.go @@ -0,0 +1,257 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" +) + +func egressHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(createEgress))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(listEgress))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(updateEgress))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(deleteEgress))).Methods(http.MethodDelete) +} + +// @Summary Create Egress Resource +// @Router /api/v1/egress [post] +// @Tags Auth +// @Accept json +// @Param body body models.Egress +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func createEgress(w http.ResponseWriter, r *http.Request) { + + var req models.EgressReq + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var egressRange string + if !req.IsInetGw { + egressRange, err = logic.NormalizeCIDR(req.Range) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + } else { + egressRange = "*" + } + + e := schema.Egress{ + ID: uuid.New().String(), + Name: req.Name, + Network: req.Network, + Description: req.Description, + Range: egressRange, + Nat: req.Nat, + IsInetGw: req.IsInetGw, + Nodes: make(datatypes.JSONMap), + Tags: make(datatypes.JSONMap), + Status: true, + CreatedBy: r.Header.Get("user"), + CreatedAt: time.Now().UTC(), + } + for nodeID, metric := range req.Nodes { + e.Nodes[nodeID] = metric + } + if err := logic.ValidateEgressReq(&e); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + err = e.Create(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error creating egress resource"+err.Error()), "internal"), + ) + return + } + // for nodeID := range e.Nodes { + // node, err := logic.GetNodeByID(nodeID) + // if err != nil { + // logic.AddEgressInfoToNode(&node, e) + // logic.UpsertNode(&node) + // } + + // } + go mq.PublishPeerUpdate(false) + logic.ReturnSuccessResponseWithJson(w, r, e, "created egress resource") +} + +// @Summary List Egress Resource +// @Router /api/v1/egress [get] +// @Tags Auth +// @Accept json +// @Param query network string +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listEgress(w http.ResponseWriter, r *http.Request) { + + network := r.URL.Query().Get("network") + if network == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), "badrequest")) + return + } + e := schema.Egress{Network: network} + list, err := e.ListByNetwork(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error listing egress resource"+err.Error()), "internal"), + ) + return + } + logic.ReturnSuccessResponseWithJson(w, r, list, "fetched egress resource list") +} + +// @Summary Update Egress Resource +// @Router /api/v1/egress [put] +// @Tags Auth +// @Accept json +// @Param body body models.Egress +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func updateEgress(w http.ResponseWriter, r *http.Request) { + + var req models.EgressReq + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var egressRange string + if !req.IsInetGw { + egressRange, err = logic.NormalizeCIDR(req.Range) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + } else { + egressRange = "*" + } + + e := schema.Egress{ID: req.ID} + err = e.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var updateNat bool + var updateInetGw bool + var updateStatus bool + if req.Nat != e.Nat { + updateNat = true + } + if req.IsInetGw != e.IsInetGw { + updateInetGw = true + } + if req.Status != e.Status { + updateStatus = true + } + e.Nodes = make(datatypes.JSONMap) + e.Tags = make(datatypes.JSONMap) + for nodeID, metric := range req.Nodes { + e.Nodes[nodeID] = metric + } + e.Range = egressRange + e.Description = req.Description + e.Name = req.Name + e.Nat = req.Nat + e.Status = req.Status + e.IsInetGw = req.IsInetGw + e.UpdatedAt = time.Now().UTC() + if err := logic.ValidateEgressReq(&e); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + err = e.Update(db.WithContext(context.TODO())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error creating egress resource"+err.Error()), "internal"), + ) + return + } + if updateNat { + e.Nat = req.Nat + e.UpdateNatStatus(db.WithContext(context.TODO())) + } + if updateInetGw { + e.IsInetGw = req.IsInetGw + e.UpdateINetGwStatus(db.WithContext(context.TODO())) + } + if updateStatus { + e.Status = req.Status + e.UpdateEgressStatus(db.WithContext(context.TODO())) + } + go mq.PublishPeerUpdate(false) + logic.ReturnSuccessResponseWithJson(w, r, e, "updated egress resource") +} + +// @Summary Delete Egress Resource +// @Router /api/v1/egress [delete] +// @Tags Auth +// @Accept json +// @Param body body models.Egress +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func deleteEgress(w http.ResponseWriter, r *http.Request) { + + id := r.URL.Query().Get("id") + if id == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest")) + return + } + e := schema.Egress{ID: id} + err := e.Delete(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + // delete related acl policies + acls := logic.ListAcls() + for _, acl := range acls { + + for i := len(acl.Dst) - 1; i >= 0; i-- { + if acl.Dst[i].ID == models.EgressID && acl.Dst[i].Value == id { + acl.Dst = append(acl.Dst[:i], acl.Dst[i+1:]...) + } + } + if len(acl.Dst) == 0 { + logic.DeleteAcl(acl) + } else { + logic.UpsertAcl(acl) + } + } + go mq.PublishPeerUpdate(false) + logic.ReturnSuccessResponseWithJson(w, r, nil, "deleted egress resource") +} diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 5ab2550e..313736a2 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -302,7 +302,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { if !hostExists { newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive // register host - logic.CheckHostPorts(&newHost) + _ = logic.CheckHostPorts(&newHost) // create EMQX credentials and ACLs for host if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { if err := mq.GetEmqxHandler().CreateEmqxUser(newHost.ID.String(), newHost.HostPass); err != nil { diff --git a/controllers/ext_client.go b/controllers/ext_client.go index 264dc846..ad8b51aa 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -174,6 +174,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + logic.GetNodeEgressInfo(&gwnode) host, err := logic.GetHost(gwnode.HostID.String()) if err != nil { logger.Log( @@ -261,7 +262,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { } var newAllowedIPs string - if logic.IsInternetGw(gwnode) || gwnode.InternetGwID != "" { + if logic.IsInternetGw(gwnode) || gwnode.EgressDetails.InternetGwID != "" { egressrange := "0.0.0.0/0" if gwnode.Address6.IP != nil && client.Address6 != "" { egressrange += "," + "::/0" @@ -540,7 +541,7 @@ func getExtClientHAConf(w http.ResponseWriter, r *http.Request) { keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive)) } var newAllowedIPs string - if logic.IsInternetGw(gwnode) || gwnode.InternetGwID != "" { + if logic.IsInternetGw(gwnode) || gwnode.EgressDetails.InternetGwID != "" { egressrange := "0.0.0.0/0" if gwnode.Address6.IP != nil && client.Address6 != "" { egressrange += "," + "::/0" @@ -688,7 +689,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { var gateway models.EgressGatewayRequest gateway.NetID = params["network"] gateway.Ranges = customExtClient.ExtraAllowedIPs - err := logic.ValidateEgressRange(gateway) + err := logic.ValidateEgressRange(gateway.NetID, gateway.Ranges) if err != nil { logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) @@ -876,7 +877,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { var gateway models.EgressGatewayRequest gateway.NetID = params["network"] gateway.Ranges = update.ExtraAllowedIPs - err = logic.ValidateEgressRange(gateway) + err = logic.ValidateEgressRange(gateway.NetID, gateway.Ranges) if err != nil { logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) diff --git a/controllers/hosts.go b/controllers/hosts.go index 5c681ca5..5bc71857 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -216,7 +216,7 @@ func pull(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(keyErr, "internal")) return } - + _ = logic.CheckHostPorts(host) serverConf.TrafficKey = key response := models.HostPull{ Host: *host, diff --git a/controllers/migrate.go b/controllers/migrate.go index 7eaab859..69e8046d 100644 --- a/controllers/migrate.go +++ b/controllers/migrate.go @@ -208,7 +208,7 @@ func convertLegacyNode(legacy models.LegacyNode, hostID uuid.UUID) models.Node { node.IsRelay = false node.RelayedNodes = []string{} node.DNSOn = models.ParseBool(legacy.DNSOn) - node.LastModified = time.Now() + node.LastModified = time.Now().UTC() node.ExpirationDateTime = time.Unix(legacy.ExpirationDateTime, 0) node.EgressGatewayNatEnabled = models.ParseBool(legacy.EgressGatewayNatEnabled) node.EgressGatewayRequest = legacy.EgressGatewayRequest diff --git a/controllers/node.go b/controllers/node.go index ee98165a..79e15bd4 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -516,7 +516,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) { } gateway.NetID = params["network"] gateway.NodeID = params["nodeid"] - err = logic.ValidateEgressRange(gateway) + err = logic.ValidateEgressRange(gateway.NetID, gateway.Ranges) if err != nil { logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) @@ -638,13 +638,6 @@ func updateNode(w http.ResponseWriter, r *http.Request) { ) return } - if newNode.IsInternetGateway != currentNode.IsInternetGateway { - if newNode.IsInternetGateway { - logic.SetInternetGw(newNode, models.InetNodeReq{}) - } else { - logic.UnsetInternetGw(newNode) - } - } relayUpdate := logic.RelayUpdates(¤tNode, newNode) if relayUpdate && newNode.IsRelay { err = logic.ValidateRelay(models.RelayRequest{ diff --git a/controllers/node_test.go b/controllers/node_test.go index 02495fd2..188e1fef 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -18,100 +18,6 @@ import ( var nonLinuxHost models.Host var linuxHost models.Host -func TestCreateEgressGateway(t *testing.T) { - var gateway models.EgressGatewayRequest - gateway.Ranges = []string{"10.100.100.0/24"} - gateway.RangesWithMetric = append(gateway.RangesWithMetric, models.EgressRangeMetric{ - Network: "10.100.100.0/24", - RouteMetric: 256, - }) - gateway.NetID = "skynet" - deleteAllNetworks() - createNet() - t.Run("NoNodes", func(t *testing.T) { - node, err := logic.CreateEgressGateway(gateway) - assert.Equal(t, models.Node{}, node) - assert.EqualError(t, err, "could not find any records") - }) - t.Run("Non-linux node", func(t *testing.T) { - createnode := createNodeWithParams("", "") - createNodeHosts() - createnode.HostID = nonLinuxHost.ID - err := logic.AssociateNodeToHost(createnode, &nonLinuxHost) - assert.Nil(t, err) - gateway.NodeID = createnode.ID.String() - node, err := logic.CreateEgressGateway(gateway) - assert.Equal(t, models.Node{}, node) - assert.EqualError(t, err, "windows is unsupported for egress gateways") - }) - t.Run("Success-Nat-Enabled", func(t *testing.T) { - deleteAllNodes() - testnode := createTestNode() - gateway.NodeID = testnode.ID.String() - gateway.NatEnabled = "yes" - - node, err := logic.CreateEgressGateway(gateway) - t.Log(node.EgressGatewayNatEnabled) - assert.Nil(t, err) - }) - t.Run("Success-Nat-Disabled", func(t *testing.T) { - deleteAllNodes() - testnode := createTestNode() - gateway.NodeID = testnode.ID.String() - gateway.NatEnabled = "no" - - node, err := logic.CreateEgressGateway(gateway) - t.Log(node.EgressGatewayNatEnabled) - assert.Nil(t, err) - }) - t.Run("Success", func(t *testing.T) { - var gateway models.EgressGatewayRequest - gateway.Ranges = []string{"10.100.100.0/24"} - gateway.NetID = "skynet" - deleteAllNodes() - testnode := createTestNode() - gateway.NodeID = testnode.ID.String() - - node, err := logic.CreateEgressGateway(gateway) - t.Log(node) - assert.Nil(t, err) - assert.Equal(t, true, node.IsEgressGateway) - assert.Equal(t, gateway.Ranges, node.EgressGatewayRanges) - }) - -} -func TestDeleteEgressGateway(t *testing.T) { - var gateway models.EgressGatewayRequest - deleteAllNetworks() - createNet() - testnode := createTestNode() - gateway.Ranges = []string{"10.100.100.0/24"} - gateway.NetID = "skynet" - gateway.NodeID = testnode.ID.String() - t.Run("Success", func(t *testing.T) { - node, err := logic.CreateEgressGateway(gateway) - assert.Nil(t, err) - assert.Equal(t, true, node.IsEgressGateway) - assert.Equal(t, []string{"10.100.100.0/24"}, node.EgressGatewayRanges) - node, err = logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID) - assert.Nil(t, err) - assert.Equal(t, false, node.IsEgressGateway) - assert.Equal(t, []string([]string{}), node.EgressGatewayRanges) - }) - t.Run("NotGateway", func(t *testing.T) { - node, err := logic.DeleteEgressGateway(gateway.NetID, gateway.NodeID) - assert.Nil(t, err) - assert.Equal(t, false, node.IsEgressGateway) - assert.Equal(t, []string([]string{}), node.EgressGatewayRanges) - }) - t.Run("BadNode", func(t *testing.T) { - node, err := logic.DeleteEgressGateway(gateway.NetID, "01:02:03") - assert.EqualError(t, err, "no result found") - assert.Equal(t, models.Node{}, node) - deleteAllNodes() - }) -} - func TestGetNetworkNodes(t *testing.T) { deleteAllNetworks() createNet() diff --git a/controllers/tags.go b/controllers/tags.go index 5dd1ad9f..fc803e0c 100644 --- a/controllers/tags.go +++ b/controllers/tags.go @@ -89,7 +89,7 @@ func createTag(w http.ResponseWriter, r *http.Request) { Network: req.Network, CreatedBy: user.UserName, ColorCode: req.ColorCode, - CreatedAt: time.Now(), + CreatedAt: time.Now().UTC(), } _, err = logic.GetTag(tag.ID) if err == nil { diff --git a/go.mod b/go.mod index a0ab22c0..dadf0d56 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.9.1 gopkg.in/mail.v2 v2.3.1 + gorm.io/datatypes v1.2.5 gorm.io/driver/postgres v1.5.11 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.26.1 @@ -57,6 +58,7 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-jose/go-jose/v4 v4.0.5 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -71,6 +73,7 @@ require ( github.com/seancfoley/bintree v1.3.1 // indirect github.com/spf13/pflag v1.0.6 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gorm.io/driver/mysql v1.5.6 // indirect ) require ( diff --git a/go.sum b/go.sum index 24f8057e..e4d312d0 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,15 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -79,6 +86,8 @@ github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4 github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/microsoft/go-mssqldb v1.7.2 h1:CHkFJiObW7ItKTJfHo1QX7QBBD1iV+mn1eOyRP3b/PA= +github.com/microsoft/go-mssqldb v1.7.2/go.mod h1:kOvZKUdrhhFQmxLZqbwUV0rHkNkZpthMITIb2Ko1IoA= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -140,9 +149,16 @@ gopkg.in/mail.v2 v2.3.1/go.mod h1:htwXN1Qh09vZJ1NVKxQqHPBaCBbzKhp5GzuJEA4VJWw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I= +gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g= +gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.26.1 h1:ghB2gUI9FkS46luZtn6DLZ0f6ooBJ5IbVej2ENFDjRw= gorm.io/gorm v1.26.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/logic/acls.go b/logic/acls.go index c11443ea..009f802e 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -1,6 +1,7 @@ package logic import ( + "context" "encoding/json" "errors" "fmt" @@ -12,10 +13,23 @@ import ( "github.com/google/uuid" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) +/* +TODO: EGRESS +1. allow only selection of egress ranges in a policy +ranges should be replaced by egress identifier + +2. check logic required for MAC exit node + +3. + +*/ + var ( aclCacheMutex = &sync.RWMutex{} aclCacheMap = make(map[string]models.Acl) @@ -236,10 +250,10 @@ func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]st if currentNode.Network != netID.String() { continue } - if currentNode.IsEgressGateway { // add the egress gateway range(s) to the result - if len(currentNode.EgressGatewayRanges) > 0 { - nodeEgressMap[currentNode.ID.String()] = currentNode.EgressGatewayRanges - for _, egressRangeI := range currentNode.EgressGatewayRanges { + if currentNode.EgressDetails.IsEgressGateway { // add the egress gateway range(s) to the result + if len(currentNode.EgressDetails.EgressGatewayRanges) > 0 { + nodeEgressMap[currentNode.ID.String()] = currentNode.EgressDetails.EgressGatewayRanges + for _, egressRangeI := range currentNode.EgressDetails.EgressGatewayRanges { resultMap[egressRangeI] = struct{}{} } } @@ -257,78 +271,102 @@ func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]st return nodeEgressMap, resultMap, nil } -func checkIfAclTagisValid(t models.AclPolicyTag, netID models.NetworkID, policyType models.AclPolicyType, isSrc bool) bool { +func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) { switch t.ID { case models.NodeTagID: - if policyType == models.UserPolicy && isSrc { - return false + if a.RuleType == models.UserPolicy && isSrc { + return errors.New("user policy source mismatch") } // check if tag is valid _, err := GetTag(models.TagID(t.Value)) if err != nil { - return false + return errors.New("invalid tag " + t.Value) } case models.NodeID: - if policyType == models.UserPolicy && isSrc { - return false + if a.RuleType == models.UserPolicy && isSrc { + return errors.New("user policy source mismatch") } _, nodeErr := GetNodeByID(t.Value) if nodeErr != nil { - _, staticNodeErr := GetExtClient(t.Value, netID.String()) + _, staticNodeErr := GetExtClient(t.Value, a.NetworkID.String()) if staticNodeErr != nil { - return false + return errors.New("invalid node " + t.Value) } } - case models.EgressRange: - if isSrc { - return false + case models.EgressID, models.EgressRange: + e := schema.Egress{ + ID: t.Value, } - // _, rangesMap, err := GetEgressRanges(netID) - // if err != nil { - // return false - // } - // if _, ok := rangesMap[t.Value]; !ok { - // return false - // } + err := e.Get(db.WithContext(context.TODO())) + if err != nil { + return errors.New("invalid egress") + } + if e.IsInetGw { + req := models.InetNodeReq{} + for _, srcI := range a.Src { + if srcI.ID == models.NodeTagID { + nodesMap := GetNodesWithTag(models.TagID(srcI.Value)) + for _, node := range nodesMap { + req.InetNodeClientIDs = append(req.InetNodeClientIDs, node.ID.String()) + } + } else if srcI.ID == models.NodeID { + req.InetNodeClientIDs = append(req.InetNodeClientIDs, srcI.Value) + } + } + if len(e.Nodes) > 0 { + for k := range e.Nodes { + inetNode, err := GetNodeByID(k) + if err != nil { + return errors.New("invalid node " + t.Value) + } + if err = ValidateInetGwReq(inetNode, req, false); err != nil { + return err + } + } + + } + + } + case models.UserAclID: - if policyType == models.DevicePolicy { - return false + if a.RuleType == models.DevicePolicy { + return errors.New("device policy source mismatch") } if !isSrc { - return false + return errors.New("user cannot be added to destination") } _, err := GetUser(t.Value) if err != nil { - return false + return errors.New("invalid user " + t.Value) } case models.UserGroupAclID: - if policyType == models.DevicePolicy { - return false + if a.RuleType == models.DevicePolicy { + return errors.New("device policy source mismatch") } if !isSrc { - return false + return errors.New("user cannot be added to destination") } err := IsGroupValid(models.UserGroupID(t.Value)) if err != nil { - return false + return errors.New("invalid user group " + t.Value) } // check if group belongs to this network - netGrps := GetUserGroupsInNetwork(netID) + netGrps := GetUserGroupsInNetwork(a.NetworkID) if _, ok := netGrps[models.UserGroupID(t.Value)]; !ok { - return false + return errors.New("invalid user group " + t.Value) } default: - return false + return errors.New("invalid policy") } - return true + return nil } // IsAclPolicyValid - validates if acl policy is valid -func IsAclPolicyValid(acl models.Acl) bool { +func IsAclPolicyValid(acl models.Acl) (err error) { //check if src and dst are valid if acl.AllowedDirection != models.TrafficDirectionBi && acl.AllowedDirection != models.TrafficDirectionUni { - return false + return errors.New("invalid traffic direction") } switch acl.RuleType { case models.UserPolicy: @@ -339,8 +377,8 @@ func IsAclPolicyValid(acl models.Acl) bool { continue } // check if user group is valid - if !checkIfAclTagisValid(srcI, acl.NetworkID, acl.RuleType, true) { - return false + if err = checkIfAclTagisValid(acl, srcI, true); err != nil { + return } } for _, dstI := range acl.Dst { @@ -350,8 +388,8 @@ func IsAclPolicyValid(acl models.Acl) bool { } // check if user group is valid - if !checkIfAclTagisValid(dstI, acl.NetworkID, acl.RuleType, false) { - return false + if err = checkIfAclTagisValid(acl, dstI, false); err != nil { + return } } case models.DevicePolicy: @@ -360,8 +398,8 @@ func IsAclPolicyValid(acl models.Acl) bool { continue } // check if user group is valid - if !checkIfAclTagisValid(srcI, acl.NetworkID, acl.RuleType, true) { - return false + if err = checkIfAclTagisValid(acl, srcI, true); err != nil { + return err } } for _, dstI := range acl.Dst { @@ -370,12 +408,26 @@ func IsAclPolicyValid(acl models.Acl) bool { continue } // check if user group is valid - if !checkIfAclTagisValid(dstI, acl.NetworkID, acl.RuleType, false) { - return false + if err = checkIfAclTagisValid(acl, dstI, false); err != nil { + return } } } - return true + return nil +} + +func UniqueAclPolicyTags(tags []models.AclPolicyTag) []models.AclPolicyTag { + seen := make(map[string]bool) + var result []models.AclPolicyTag + + for _, tag := range tags { + key := fmt.Sprintf("%v-%s", tag.ID, tag.Value) + if !seen[key] { + seen[key] = true + result = append(result, tag) + } + } + return result } // UpdateAcl - updates allowed fields on acls and commits to DB @@ -623,6 +675,17 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode continue } dstMap := convAclTagToValueMap(policy.Dst) + for _, dst := range policy.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstMap[nodeID] = struct{}{} + } + } + } + } if _, ok := dstMap["*"]; ok { allowedPolicies = append(allowedPolicies, policy) continue @@ -712,8 +775,20 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool { if !policy.Enabled { continue } + srcMap = convAclTagToValueMap(policy.Src) dstMap = convAclTagToValueMap(policy.Dst) + for _, dst := range policy.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstMap[nodeID] = struct{}{} + } + } + } + } if checkTagGroupPolicy(srcMap, dstMap, node, peer, nodeTags, peerTags) { return true } @@ -975,6 +1050,17 @@ func IsNodeAllowedToCommunicateV1(node, peer models.Node, checkDefaultPolicy boo allowed := false srcMap = convAclTagToValueMap(policy.Src) dstMap = convAclTagToValueMap(policy.Dst) + for _, dst := range policy.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstMap[nodeID] = struct{}{} + } + } + } + } _, srcAll := srcMap["*"] _, dstAll := dstMap["*"] if policy.AllowedDirection == models.TrafficDirectionBi { @@ -1158,7 +1244,7 @@ func getEgressUserRulesForNode(targetnode *models.Node, acls := listUserPolicies(models.NetworkID(targetnode.Network)) var targetNodeTags = make(map[models.TagID]struct{}) targetNodeTags["*"] = struct{}{} - for _, rangeI := range targetnode.EgressGatewayRanges { + for _, rangeI := range targetnode.EgressDetails.EgressGatewayRanges { targetNodeTags[models.TagID(rangeI)] = struct{}{} } for _, acl := range acls { @@ -1166,6 +1252,18 @@ func getEgressUserRulesForNode(targetnode *models.Node, continue } dstTags := convAclTagToValueMap(acl.Dst) + for _, dst := range acl.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstTags[nodeID] = struct{}{} + } + dstTags[e.Range] = struct{}{} + } + } + } _, all := dstTags["*"] addUsers := false if !all { @@ -1225,16 +1323,34 @@ func getEgressUserRulesForNode(targetnode *models.Node, r.IP6List = append(r.IP6List, userNode.StaticNode.AddressIPNet6()) } for _, dstI := range acl.Dst { - if dstI.ID == models.EgressRange { - ip, cidr, err := net.ParseCIDR(dstI.Value) - if err == nil { - if ip.To4() != nil { - r.Dst = append(r.Dst, *cidr) - } else { - r.Dst6 = append(r.Dst6, *cidr) - } - + if dstI.ID == models.EgressID { + e := schema.Egress{ID: dstI.Value} + err := e.Get(db.WithContext(context.TODO())) + if err != nil { + continue } + if e.IsInetGw { + r.Dst = append(r.Dst, net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + }) + r.Dst6 = append(r.Dst6, net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + }) + + } else { + ip, cidr, err := net.ParseCIDR(e.Range) + if err == nil { + if ip.To4() != nil { + r.Dst = append(r.Dst, *cidr) + } else { + r.Dst6 = append(r.Dst6, *cidr) + } + + } + } + } } @@ -1348,7 +1464,7 @@ func getUserAclRulesForNode(targetnode *models.Node, } func checkIfAnyActiveEgressPolicy(targetNode models.Node) bool { - if !targetNode.IsEgressGateway { + if !targetNode.EgressDetails.IsEgressGateway { return false } var targetNodeTags = make(map[models.TagID]struct{}) @@ -1371,8 +1487,20 @@ func checkIfAnyActiveEgressPolicy(targetNode models.Node) bool { } srcTags := convAclTagToValueMap(acl.Src) dstTags := convAclTagToValueMap(acl.Dst) + for _, dst := range acl.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstTags[nodeID] = struct{}{} + } + dstTags[e.Range] = struct{}{} + } + } + } for nodeTag := range targetNodeTags { - if acl.RuleType == models.DevicePolicy { + if acl.RuleType == models.DevicePolicy && acl.AllowedDirection == models.TrafficDirectionBi { if _, ok := srcTags[nodeTag.String()]; ok { return true } @@ -1440,21 +1568,7 @@ func checkIfAnyPolicyisUniDirectional(targetNode models.Node) bool { return false } -func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRule) { - targetnode := *targetnodeI - defer func() { - if !targetnode.IsIngressGateway { - rules = getUserAclRulesForNode(&targetnode, rules) - } - }() - rules = make(map[string]models.AclRule) - var taggedNodes map[models.TagID][]models.Node - if targetnode.IsIngressGateway { - taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false) - } else { - taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true) - } - +func checkIfNodeHasAccessToAllResources(targetnode *models.Node) bool { acls := listDevicePolicies(models.NetworkID(targetnode.Network)) var targetNodeTags = make(map[models.TagID]struct{}) if targetnode.Mutex != nil { @@ -1477,6 +1591,85 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu dstTags := convAclTagToValueMap(acl.Dst) _, srcAll := srcTags["*"] _, dstAll := dstTags["*"] + for nodeTag := range targetNodeTags { + + var existsInSrcTag bool + var existsInDstTag bool + + if _, ok := srcTags[nodeTag.String()]; ok { + existsInSrcTag = true + } + if _, ok := srcTags[targetnode.ID.String()]; ok { + existsInSrcTag = true + } + if _, ok := dstTags[nodeTag.String()]; ok { + existsInDstTag = true + } + if _, ok := dstTags[targetnode.ID.String()]; ok { + existsInDstTag = true + } + if acl.AllowedDirection == models.TrafficDirectionBi { + if existsInSrcTag && dstAll || existsInDstTag && srcAll { + return true + } + } else { + if existsInDstTag && srcAll { + return true + } + } + } + } + return false +} + +func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRule) { + targetnode := *targetnodeI + defer func() { + if !targetnode.IsIngressGateway { + rules = getUserAclRulesForNode(&targetnode, rules) + } + }() + rules = make(map[string]models.AclRule) + var taggedNodes map[models.TagID][]models.Node + if targetnode.IsIngressGateway { + taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false) + } else { + taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true) + } + fmt.Printf("TAGGED NODES: %+v\n", taggedNodes) + acls := listDevicePolicies(models.NetworkID(targetnode.Network)) + var targetNodeTags = make(map[models.TagID]struct{}) + if targetnode.Mutex != nil { + targetnode.Mutex.Lock() + targetNodeTags = maps.Clone(targetnode.Tags) + targetnode.Mutex.Unlock() + } else { + targetNodeTags = maps.Clone(targetnode.Tags) + } + if targetNodeTags == nil { + targetNodeTags = make(map[models.TagID]struct{}) + } + targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{} + targetNodeTags["*"] = struct{}{} + for _, acl := range acls { + if !acl.Enabled { + continue + } + srcTags := convAclTagToValueMap(acl.Src) + dstTags := convAclTagToValueMap(acl.Dst) + for _, dst := range acl.Dst { + if dst.ID == models.EgressID { + e := schema.Egress{ID: dst.Value} + err := e.Get(db.WithContext(context.TODO())) + if err == nil && e.Status { + for nodeID := range e.Nodes { + dstTags[nodeID] = struct{}{} + } + } + } + } + _, srcAll := srcTags["*"] + _, dstAll := dstTags["*"] aclRule := models.AclRule{ ID: acl.ID, AllowedProtocol: acl.Proto, @@ -1502,7 +1695,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu existsInDstTag = true } - if existsInSrcTag && !existsInDstTag { + if existsInSrcTag /* && !existsInDstTag*/ { // get all dst tags for dst := range dstTags { if dst == nodeTag.String() { @@ -1539,7 +1732,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu } } } - if existsInDstTag && !existsInSrcTag { + if existsInDstTag /*&& !existsInSrcTag*/ { // get all src tags for src := range srcTags { if src == nodeTag.String() { @@ -1575,47 +1768,47 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu } } } - if existsInDstTag && existsInSrcTag { - nodes := taggedNodes[nodeTag] - for srcID := range srcTags { - if srcID == targetnode.ID.String() { - continue - } - node, err := GetNodeByID(srcID) - if err == nil { - nodes = append(nodes, node) - } - } - for dstID := range dstTags { - if dstID == targetnode.ID.String() { - continue - } - node, err := GetNodeByID(dstID) - if err == nil { - nodes = append(nodes, node) - } - } - for _, node := range nodes { - if node.ID == targetnode.ID { - continue - } - if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() { - continue - } - if node.Address.IP != nil { - aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) - } - if node.Address6.IP != nil { - aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6()) - } - if node.IsStatic && node.StaticNode.Address != "" { - aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4()) - } - if node.IsStatic && node.StaticNode.Address6 != "" { - aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6()) - } - } - } + // if existsInDstTag && existsInSrcTag { + // nodes := taggedNodes[nodeTag] + // for srcID := range srcTags { + // if srcID == targetnode.ID.String() { + // continue + // } + // node, err := GetNodeByID(srcID) + // if err == nil { + // nodes = append(nodes, node) + // } + // } + // for dstID := range dstTags { + // if dstID == targetnode.ID.String() { + // continue + // } + // node, err := GetNodeByID(dstID) + // if err == nil { + // nodes = append(nodes, node) + // } + // } + // for _, node := range nodes { + // if node.ID == targetnode.ID { + // continue + // } + // if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() { + // continue + // } + // if node.Address.IP != nil { + // aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4()) + // } + // if node.Address6.IP != nil { + // aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6()) + // } + // if node.IsStatic && node.StaticNode.Address != "" { + // aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4()) + // } + // if node.IsStatic && node.StaticNode.Address6 != "" { + // aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6()) + // } + // } + // } } else { _, all := dstTags["*"] if _, ok := dstTags[nodeTag.String()]; ok || all { @@ -1677,9 +1870,23 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR if acl policy has egress route and it is present in target node egress ranges fetch all the nodes in that policy and add rules */ - - for _, rangeI := range targetnode.EgressGatewayRanges { - targetNodeTags[models.TagID(rangeI)] = struct{}{} + egs, _ := (&schema.Egress{Network: targetnode.Network}).ListByNetwork(db.WithContext(context.TODO())) + if len(egs) == 0 { + return + } + for _, egI := range egs { + if !egI.Status { + continue + } + if _, ok := egI.Nodes[targetnode.ID.String()]; ok { + if egI.Range == "*" { + targetNodeTags[models.TagID("0.0.0.0/0")] = struct{}{} + targetNodeTags[models.TagID("::/0")] = struct{}{} + } else { + targetNodeTags[models.TagID(egI.Range)] = struct{}{} + } + targetNodeTags[models.TagID(egI.ID)] = struct{}{} + } } for _, acl := range acls { if !acl.Enabled { @@ -1689,46 +1896,43 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR dstTags := convAclTagToValueMap(acl.Dst) _, srcAll := srcTags["*"] _, dstAll := dstTags["*"] + aclRule := models.AclRule{ + ID: acl.ID, + AllowedProtocol: acl.Proto, + AllowedPorts: acl.Port, + Direction: acl.AllowedDirection, + Allowed: true, + } for nodeTag := range targetNodeTags { - aclRule := models.AclRule{ - ID: acl.ID, - AllowedProtocol: acl.Proto, - AllowedPorts: acl.Port, - Direction: acl.AllowedDirection, - Allowed: true, - } + if nodeTag != "*" { ip, cidr, err := net.ParseCIDR(nodeTag.String()) - if err != nil { - continue + if err == nil { + if ip.To4() != nil { + aclRule.Dst = append(aclRule.Dst, *cidr) + } else { + aclRule.Dst6 = append(aclRule.Dst6, *cidr) + } } - if ip.To4() != nil { - aclRule.Dst = append(aclRule.Dst, *cidr) - } else { - aclRule.Dst6 = append(aclRule.Dst6, *cidr) - } - - } else { - aclRule.Dst = append(aclRule.Dst, net.IPNet{ - IP: net.IPv4zero, // 0.0.0.0 - Mask: net.CIDRMask(0, 32), // /0 means match all IPv4 - }) - aclRule.Dst6 = append(aclRule.Dst6, net.IPNet{ - IP: net.IPv6zero, // :: - Mask: net.CIDRMask(0, 128), // /0 means match all IPv6 - }) } if acl.AllowedDirection == models.TrafficDirectionBi { var existsInSrcTag bool var existsInDstTag bool - if _, ok := srcTags[nodeTag.String()]; ok || srcAll { existsInSrcTag = true } if _, ok := dstTags[nodeTag.String()]; ok || dstAll { existsInDstTag = true } - + if srcAll || dstAll { + if targetnode.NetworkRange.IP != nil { + aclRule.IPList = append(aclRule.IPList, targetnode.NetworkRange) + } + if targetnode.NetworkRange6.IP != nil { + aclRule.IP6List = append(aclRule.IP6List, targetnode.NetworkRange6) + } + break + } if existsInSrcTag && !existsInDstTag { // get all dst tags for dst := range dstTags { @@ -1835,8 +2039,16 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR } } } else { - _, all := dstTags["*"] - if _, ok := dstTags[nodeTag.String()]; ok || all { + if dstAll { + if targetnode.NetworkRange.IP != nil { + aclRule.IPList = append(aclRule.IPList, targetnode.NetworkRange) + } + if targetnode.NetworkRange6.IP != nil { + aclRule.IP6List = append(aclRule.IP6List, targetnode.NetworkRange6) + } + break + } + if _, ok := dstTags[nodeTag.String()]; ok || dstAll { // get all src tags for src := range srcTags { if src == nodeTag.String() { @@ -1864,13 +2076,13 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR } } } - if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 { - aclRule.IPList = UniqueIPNetList(aclRule.IPList) - aclRule.IP6List = UniqueIPNetList(aclRule.IP6List) - rules[acl.ID] = aclRule - } } + if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 { + aclRule.IPList = UniqueIPNetList(aclRule.IPList) + aclRule.IP6List = UniqueIPNetList(aclRule.IP6List) + rules[acl.ID] = aclRule + } } return diff --git a/logic/auth.go b/logic/auth.go index 611cb0f0..b2fa7d0d 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -240,7 +240,7 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) { } // update last login time - result.LastLoginTime = time.Now() + result.LastLoginTime = time.Now().UTC() err = UpsertUser(result) if err != nil { slog.Error("error upserting user", "error", err) diff --git a/logic/egress.go b/logic/egress.go new file mode 100644 index 00000000..6dd5fc1b --- /dev/null +++ b/logic/egress.go @@ -0,0 +1,366 @@ +package logic + +import ( + "context" + "encoding/json" + "errors" + "maps" + "net" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" +) + +func ValidateEgressReq(e *schema.Egress) error { + if e.Network == "" { + return errors.New("network id is empty") + } + _, err := GetNetwork(e.Network) + if err != nil { + return errors.New("failed to get network " + err.Error()) + } + if !e.IsInetGw { + if e.Range == "" { + return errors.New("egress range is empty") + } + _, _, err = net.ParseCIDR(e.Range) + if err != nil { + return errors.New("invalid egress range " + err.Error()) + } + err = ValidateEgressRange(e.Network, []string{e.Range}) + if err != nil { + return errors.New("invalid egress range " + err.Error()) + } + } else { + if len(e.Nodes) > 1 { + return errors.New("can only set one internet routing node") + } + req := models.InetNodeReq{} + + for k := range e.Nodes { + inetNode, err := GetNodeByID(k) + if err != nil { + return errors.New("invalid routing node " + err.Error()) + } + // check if node is acting as egress gw already + GetNodeEgressInfo(&inetNode) + if err := ValidateInetGwReq(inetNode, req, false); err != nil { + return err + } + + } + + } + if len(e.Nodes) != 0 { + for k := range e.Nodes { + _, err := GetNodeByID(k) + if err != nil { + return errors.New("invalid routing node " + err.Error()) + } + } + } + return nil +} + +func GetInetClientsFromAclPolicies(eID string) (inetClientIDs []string) { + e := schema.Egress{ID: eID} + err := e.Get(db.WithContext(context.TODO())) + if err != nil || !e.Status { + return + } + acls, _ := ListAclsByNetwork(models.NetworkID(e.Network)) + for _, acl := range acls { + for _, dstI := range acl.Dst { + if dstI.ID == models.EgressID { + if dstI.Value != eID { + continue + } + for _, srcI := range acl.Src { + if srcI.Value == "*" { + continue + } + if srcI.ID == models.NodeID { + inetClientIDs = append(inetClientIDs, srcI.Value) + } + if srcI.ID == models.NodeTagID { + inetClientIDs = append(inetClientIDs, GetNodeIDsWithTag(models.TagID(srcI.Value))...) + } + } + } + } + } + return +} + +func isNodeUsingInternetGw(node *models.Node) { + host, err := GetHost(node.HostID.String()) + if err != nil { + return + } + if host.IsDefault || node.IsFailOver { + return + } + nodeTags := maps.Clone(node.Tags) + nodeTags[models.TagID(node.ID.String())] = struct{}{} + acls, _ := ListAclsByNetwork(models.NetworkID(node.Network)) + var isUsing bool + for _, acl := range acls { + if !acl.Enabled { + continue + } + srcVal := convAclTagToValueMap(acl.Src) + for _, dstI := range acl.Dst { + if dstI.ID == models.EgressID { + e := schema.Egress{ID: dstI.Value} + err := e.Get(db.WithContext(context.TODO())) + if err != nil || !e.Status { + continue + } + + if e.IsInetGw { + if _, ok := srcVal[node.ID.String()]; ok { + for nodeID := range e.Nodes { + if nodeID == node.ID.String() { + continue + } + node.EgressDetails.InternetGwID = nodeID + isUsing = true + return + } + } + for tagID := range nodeTags { + if _, ok := srcVal[tagID.String()]; ok { + for nodeID := range e.Nodes { + if nodeID == node.ID.String() { + continue + } + node.EgressDetails.InternetGwID = nodeID + isUsing = true + return + } + } + } + } + } + } + } + if !isUsing { + node.EgressDetails.InternetGwID = "" + } +} + +func DoesNodeHaveAccessToEgress(node *models.Node, e *schema.Egress) bool { + nodeTags := maps.Clone(node.Tags) + nodeTags[models.TagID(node.ID.String())] = struct{}{} + if !e.IsInetGw { + nodeTags[models.TagID("*")] = struct{}{} + } + acls, _ := ListAclsByNetwork(models.NetworkID(node.Network)) + if !e.IsInetGw { + defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy) + if defaultDevicePolicy.Enabled { + return true + } + } + for _, acl := range acls { + if !acl.Enabled { + continue + } + srcVal := convAclTagToValueMap(acl.Src) + if !e.IsInetGw && acl.AllowedDirection == models.TrafficDirectionBi { + if _, ok := srcVal["*"]; ok { + return true + } + } + for _, dstI := range acl.Dst { + + if !e.IsInetGw && dstI.ID == models.NodeTagID && dstI.Value == "*" { + return true + } + if dstI.ID == models.EgressID && dstI.Value == e.ID { + e := schema.Egress{ID: dstI.Value} + err := e.Get(db.WithContext(context.TODO())) + if err != nil { + continue + } + if node.IsStatic { + if _, ok := srcVal[node.StaticNode.ClientID]; ok { + return true + } + } else { + if _, ok := srcVal[node.ID.String()]; ok { + return true + } + } + + for tagID := range nodeTags { + if _, ok := srcVal[tagID.String()]; ok { + return true + } + } + + } + } + } + return false +} + +func AddEgressInfoToPeerByAccess(node, targetNode *models.Node) { + eli, _ := (&schema.Egress{Network: targetNode.Network}).ListByNetwork(db.WithContext(context.TODO())) + req := models.EgressGatewayRequest{ + NodeID: targetNode.ID.String(), + NetID: targetNode.Network, + } + defer func() { + if targetNode.Mutex != nil { + targetNode.Mutex.Lock() + } + isNodeUsingInternetGw(targetNode) + if targetNode.Mutex != nil { + targetNode.Mutex.Unlock() + } + }() + for _, e := range eli { + if !e.Status || e.Network != targetNode.Network { + continue + } + if !DoesNodeHaveAccessToEgress(node, &e) { + if node.IsRelayed && node.RelayedBy == targetNode.ID.String() { + if !DoesNodeHaveAccessToEgress(targetNode, &e) { + continue + } + } else { + continue + } + + } + if metric, ok := e.Nodes[targetNode.ID.String()]; ok { + if e.IsInetGw { + targetNode.EgressDetails.IsInternetGateway = true + targetNode.EgressDetails.InetNodeReq = models.InetNodeReq{ + InetNodeClientIDs: GetInetClientsFromAclPolicies(e.ID), + } + req.Ranges = append(req.Ranges, "0.0.0.0/0") + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: "0.0.0.0/0", + Nat: true, + RouteMetric: 256, + }) + req.Ranges = append(req.Ranges, "::/0") + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: "::/0", + Nat: true, + RouteMetric: 256, + }) + } else { + m64, err := metric.(json.Number).Int64() + if err != nil { + m64 = 256 + } + m := uint32(m64) + req.Ranges = append(req.Ranges, e.Range) + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: e.Range, + Nat: e.Nat, + RouteMetric: m, + }) + } + + } + } + if targetNode.Mutex != nil { + targetNode.Mutex.Lock() + } + if len(req.Ranges) > 0 { + + targetNode.EgressDetails.IsEgressGateway = true + targetNode.EgressDetails.EgressGatewayRanges = req.Ranges + targetNode.EgressDetails.EgressGatewayRequest = req + + } else { + targetNode.EgressDetails = models.EgressDetails{} + } + if targetNode.Mutex != nil { + targetNode.Mutex.Unlock() + } +} + +func GetNodeEgressInfo(targetNode *models.Node) { + eli, _ := (&schema.Egress{Network: targetNode.Network}).ListByNetwork(db.WithContext(context.TODO())) + req := models.EgressGatewayRequest{ + NodeID: targetNode.ID.String(), + NetID: targetNode.Network, + } + defer func() { + if targetNode.Mutex != nil { + targetNode.Mutex.Lock() + } + isNodeUsingInternetGw(targetNode) + if targetNode.Mutex != nil { + targetNode.Mutex.Unlock() + } + }() + for _, e := range eli { + if !e.Status || e.Network != targetNode.Network { + continue + } + if metric, ok := e.Nodes[targetNode.ID.String()]; ok { + if e.IsInetGw { + targetNode.EgressDetails.IsInternetGateway = true + targetNode.EgressDetails.InetNodeReq = models.InetNodeReq{ + InetNodeClientIDs: GetInetClientsFromAclPolicies(e.ID), + } + req.Ranges = append(req.Ranges, "0.0.0.0/0") + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: "0.0.0.0/0", + Nat: true, + RouteMetric: 256, + }) + req.Ranges = append(req.Ranges, "::/0") + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: "::/0", + Nat: true, + RouteMetric: 256, + }) + } else { + m64, err := metric.(json.Number).Int64() + if err != nil { + m64 = 256 + } + m := uint32(m64) + req.Ranges = append(req.Ranges, e.Range) + req.RangesWithMetric = append(req.RangesWithMetric, models.EgressRangeMetric{ + Network: e.Range, + Nat: e.Nat, + RouteMetric: m, + }) + } + + } + } + if targetNode.Mutex != nil { + targetNode.Mutex.Lock() + } + if len(req.Ranges) > 0 { + targetNode.EgressDetails.IsEgressGateway = true + targetNode.EgressDetails.EgressGatewayRanges = req.Ranges + targetNode.EgressDetails.EgressGatewayRequest = req + } else { + targetNode.EgressDetails = models.EgressDetails{} + } + if targetNode.Mutex != nil { + targetNode.Mutex.Unlock() + } +} + +func RemoveNodeFromEgress(node models.Node) { + egs, _ := (&schema.Egress{}).ListByNetwork(db.WithContext(context.TODO())) + for _, egI := range egs { + if _, ok := egI.Nodes[node.ID.String()]; ok { + delete(egI.Nodes, node.ID.String()) + egI.Update(db.WithContext(context.TODO())) + } + } + +} diff --git a/logic/extpeers.go b/logic/extpeers.go index 42c527ac..4dce7a7b 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -1,6 +1,7 @@ package logic import ( + "context" "encoding/json" "errors" "fmt" @@ -13,9 +14,11 @@ import ( "github.com/goombaio/namegenerator" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -71,13 +74,19 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) { if err != nil { return []string{}, err } + // clientNode := client.ConvertToStaticNode() for _, currentNode := range networkNodes { if currentNode.Network != client.Network { continue } - if currentNode.IsEgressGateway { // add the egress gateway range(s) to the result - if len(currentNode.EgressGatewayRanges) > 0 { - result = append(result, currentNode.EgressGatewayRanges...) + GetNodeEgressInfo(¤tNode) + if currentNode.EgressDetails.IsInternetGateway && client.IngressGatewayID != currentNode.ID.String() { + continue + } + if currentNode.EgressDetails.IsEgressGateway { // add the egress gateway range(s) to the result + fmt.Println("EGRESSS EXTCLEINT: ", currentNode.EgressDetails) + if len(currentNode.EgressDetails.EgressGatewayRanges) > 0 { + result = append(result, currentNode.EgressDetails.EgressGatewayRanges...) } } } @@ -627,7 +636,15 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode // add egress range rules for _, dstI := range policy.Dst { - if dstI.ID == models.EgressRange { + if dstI.ID == models.EgressID { + + e := schema.Egress{ID: dstI.Value} + err := e.Get(db.WithContext(context.TODO())) + if err != nil { + continue + } + dstI.Value = e.Range + ip, cidr, err := net.ParseCIDR(dstI.Value) if err == nil { if ip.To4() != nil { @@ -708,7 +725,15 @@ func getFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] // add egress ranges for _, dstI := range policy.Dst { - if dstI.ID == models.EgressRange { + if dstI.ID == models.EgressID { + + e := schema.Egress{ID: dstI.Value} + err := e.Get(db.WithContext(context.TODO())) + if err != nil { + continue + } + dstI.Value = e.Range + ip, cidr, err := net.ParseCIDR(dstI.Value) if err == nil { if ip.To4() != nil && userNodeI.StaticNode.Address != "" { diff --git a/logic/gateway.go b/logic/gateway.go index b07f332b..a8c96e9a 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -1,6 +1,7 @@ package logic import ( + "context" "errors" "fmt" "slices" @@ -8,14 +9,27 @@ import ( "time" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) // IsInternetGw - checks if node is acting as internet gw func IsInternetGw(node models.Node) bool { - return node.IsInternetGateway + e := schema.Egress{ + Network: node.Network, + } + egList, _ := e.ListByNetwork(db.WithContext(context.TODO())) + for _, egI := range egList { + if egI.IsInetGw { + if _, ok := egI.Nodes[node.ID.String()]; ok { + return true + } + } + } + return false } // GetInternetGateways - gets all the nodes that are internet gateways @@ -26,7 +40,7 @@ func GetInternetGateways() ([]models.Node, error) { } igs := make([]models.Node, 0) for _, node := range nodes { - if node.IsInternetGateway { + if node.EgressDetails.IsInternetGateway { igs = append(igs, node) } } @@ -56,7 +70,7 @@ func GetAllEgresses() ([]models.Node, error) { } egresses := make([]models.Node, 0) for _, node := range nodes { - if node.IsEgressGateway { + if node.EgressDetails.IsEgressGateway { egresses = append(egresses, node) } } @@ -133,11 +147,11 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro if gateway.Ranges == nil { gateway.Ranges = make([]string, 0) } - node.IsEgressGateway = true - node.EgressGatewayRanges = gateway.Ranges - node.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled) + node.EgressDetails.IsEgressGateway = true + node.EgressDetails.EgressGatewayRanges = gateway.Ranges + node.EgressDetails.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled) - node.EgressGatewayRequest = gateway // store entire request for use when preserving the egress gateway + node.EgressDetails.EgressGatewayRequest = gateway // store entire request for use when preserving the egress gateway node.SetLastModified() if err = UpsertNode(&node); err != nil { return models.Node{}, err @@ -156,9 +170,9 @@ func DeleteEgressGateway(network, nodeid string) (models.Node, error) { if err != nil { return models.Node{}, err } - node.IsEgressGateway = false - node.EgressGatewayRanges = []string{} - node.EgressGatewayRequest = models.EgressGatewayRequest{} // remove preserved request as the egress gateway is gone + node.EgressDetails.IsEgressGateway = false + node.EgressDetails.EgressGatewayRanges = []string{} + node.EgressDetails.EgressGatewayRequest = models.EgressGatewayRequest{} // remove preserved request as the egress gateway is gone node.SetLastModified() if err = UpsertNode(&node); err != nil { return models.Node{}, err @@ -191,12 +205,12 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq node.IsIngressGateway = true node.IsGw = true if !servercfg.IsPro { - node.IsInternetGateway = ingress.IsInternetGateway + node.EgressDetails.IsInternetGateway = ingress.IsInternetGateway } node.IngressGatewayRange = network.AddressRange node.IngressGatewayRange6 = network.AddressRange6 node.IngressDNS = ingress.ExtclientDNS - if node.IsInternetGateway && node.IngressDNS == "" { + if node.EgressDetails.IsInternetGateway && node.IngressDNS == "" { node.IngressDNS = "1.1.1.1" } node.IngressPersistentKeepalive = 20 @@ -267,10 +281,10 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error return models.Node{}, removedClients, err } logger.Log(3, "deleting ingress gateway") - node.LastModified = time.Now() + node.LastModified = time.Now().UTC() node.IsIngressGateway = false if !servercfg.IsPro { - node.IsInternetGateway = false + node.EgressDetails.IsInternetGateway = false } delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))) node.IngressGatewayRange = "" diff --git a/logic/hosts.go b/logic/hosts.go index 683bf06c..a66e04c2 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -548,17 +548,29 @@ func GetRelatedHosts(hostID string) []models.Host { // CheckHostPort checks host endpoints to ensures that hosts on the same server // with the same endpoint have different listen ports // in the case of 64535 hosts or more with same endpoint, ports will not be changed -func CheckHostPorts(h *models.Host) { +func CheckHostPorts(h *models.Host) (changed bool) { portsInUse := make(map[int]bool, 0) hosts, err := GetAllHosts() if err != nil { return } + originalPort := h.ListenPort + defer func() { + if originalPort != h.ListenPort { + changed = true + } + }() + if h.EndpointIP == nil { + return + } for _, host := range hosts { if host.ID.String() == h.ID.String() { // skip self continue } + if host.EndpointIP == nil { + continue + } if !host.EndpointIP.Equal(h.EndpointIP) { continue } @@ -566,11 +578,16 @@ func CheckHostPorts(h *models.Host) { } // iterate until port is not found or max iteration is reached for i := 0; portsInUse[h.ListenPort] && i < maxPort-minPort+1; i++ { - h.ListenPort++ + if h.ListenPort == 443 { + h.ListenPort = 51821 + } else { + h.ListenPort++ + } if h.ListenPort > maxPort { h.ListenPort = minPort } } + return } // HostExists - checks if given host already exists diff --git a/logic/jwts.go b/logic/jwts.go index 2dc1cf82..64eea672 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -135,7 +135,7 @@ func GetUserNameFromToken(authtoken string) (username string, err error) { err = errors.New("token revoked") return "", err } - a.LastUsed = time.Now() + a.LastUsed = time.Now().UTC() a.Update(db.WithContext(context.TODO())) } } @@ -179,7 +179,7 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin err = errors.New("token revoked") return "", false, false, err } - a.LastUsed = time.Now() + a.LastUsed = time.Now().UTC() a.Update(db.WithContext(context.TODO())) } } diff --git a/logic/networks.go b/logic/networks.go index be914f0c..43f4485e 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -522,7 +522,6 @@ func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) { var network models.Network network, err := GetParentNetwork(networkName) if err != nil { - fmt.Println("Network Not Found") return add, err } if network.IsIPv6 == "no" { @@ -567,7 +566,6 @@ func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) { var network models.Network network, err := GetParentNetwork(networkName) if err != nil { - fmt.Println("Network Not Found") return add, err } if network.IsIPv6 == "no" { diff --git a/logic/nodes.go b/logic/nodes.go index 98687a4a..b6f00027 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -164,7 +164,7 @@ func UpdateNodeCheckin(node *models.Node) error { if err != nil { return err } - + node.EgressDetails = models.EgressDetails{} err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) if err != nil { return err @@ -183,6 +183,7 @@ func UpsertNode(newNode *models.Node) error { if err != nil { return err } + newNode.EgressDetails = models.EgressDetails{} err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) if err != nil { return err @@ -218,7 +219,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { return err } } - + newNode.EgressDetails = models.EgressDetails{} newNode.SetLastModified() if data, err := json.Marshal(newNode); err != nil { return err @@ -280,21 +281,21 @@ func DeleteNode(node *models.Node, purge bool) error { // unset all the relayed nodes SetRelayedNodes(false, node.ID.String(), node.RelayedNodes) } - if node.InternetGwID != "" { - inetNode, err := GetNodeByID(node.InternetGwID) + if node.EgressDetails.InternetGwID != "" { + inetNode, err := GetNodeByID(node.EgressDetails.InternetGwID) if err == nil { clientNodeIDs := []string{} - for _, inetNodeClientID := range inetNode.InetNodeReq.InetNodeClientIDs { + for _, inetNodeClientID := range inetNode.EgressDetails.InetNodeReq.InetNodeClientIDs { if inetNodeClientID == node.ID.String() { continue } clientNodeIDs = append(clientNodeIDs, inetNodeClientID) } - inetNode.InetNodeReq.InetNodeClientIDs = clientNodeIDs + inetNode.EgressDetails.InetNodeReq.InetNodeClientIDs = clientNodeIDs UpsertNode(&inetNode) } } - if node.IsInternetGateway { + if node.EgressDetails.IsInternetGateway { UnsetInternetGw(node) } if !purge && !alreadyDeleted { @@ -320,8 +321,9 @@ func DeleteNode(node *models.Node, purge bool) error { if err := DissasociateNodeFromHost(node, host); err != nil { return err } - go RemoveNodeFromAclPolicy(*node) + go RemoveNodeFromAclPolicy(*node) + go RemoveNodeFromEgress(*node) return nil } @@ -783,16 +785,16 @@ func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error { return nil } -func ValidateEgressRange(gateway models.EgressGatewayRequest) error { - network, err := GetNetworkSettings(gateway.NetID) +func ValidateEgressRange(netID string, ranges []string) error { + network, err := GetNetworkSettings(netID) if err != nil { - slog.Error("error getting network with netid", "error", gateway.NetID, err.Error) - return errors.New("error getting network with netid: " + gateway.NetID + " " + err.Error()) + slog.Error("error getting network with netid", "error", netID, err.Error) + return errors.New("error getting network with netid: " + netID + " " + err.Error()) } ipv4Net := network.AddressRange ipv6Net := network.AddressRange6 - for _, v := range gateway.Ranges { + for _, v := range ranges { if ipv4Net != "" { if ContainsCIDR(ipv4Net, v) { slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv4Net) @@ -949,6 +951,30 @@ func AddTagMapWithStaticNodesWithUsers(netID models.NetworkID, return tagNodesMap } +func GetNodeIDsWithTag(tagID models.TagID) (ids []string) { + + tag, err := GetTag(tagID) + if err != nil { + return + } + nodes, _ := GetNetworkNodes(tag.Network.String()) + for _, nodeI := range nodes { + if nodeI.Tags == nil { + continue + } + if nodeI.Mutex != nil { + nodeI.Mutex.Lock() + } + if _, ok := nodeI.Tags[tagID]; ok { + ids = append(ids, nodeI.ID.String()) + } + if nodeI.Mutex != nil { + nodeI.Mutex.Unlock() + } + } + return +} + func GetNodesWithTag(tagID models.TagID) map[string]models.Node { nMap := make(map[string]models.Node) tag, err := GetTag(tagID) diff --git a/logic/peers.go b/logic/peers.go index b7d1c452..cd4c6e89 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" + "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic/acls/nodeacls" @@ -47,16 +48,19 @@ var ( } // UnsetInternetGw UnsetInternetGw = func(node *models.Node) { - node.IsInternetGateway = false + node.EgressDetails.IsInternetGateway = false } // SetInternetGw SetInternetGw = func(node *models.Node, req models.InetNodeReq) { - node.IsInternetGateway = true + node.EgressDetails.IsInternetGateway = true } // GetAllowedIpForInetNodeClient GetAllowedIpForInetNodeClient = func(node, peer *models.Node) []net.IPNet { return []net.IPNet{} } + ValidateInetGwReq = func(inetNode models.Node, req models.InetNodeReq, update bool) error { + return nil + } ) // GetHostPeerInfo - fetches required peer info per network @@ -161,26 +165,16 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N } defer func() { if !hostPeerUpdate.FwUpdate.AllowAll { - aclRule := models.AclRule{ - ID: "allowed-network-rules", - AllowedProtocol: models.ALL, - Direction: models.TrafficDirectionBi, - Allowed: true, - } - for _, allowedNet := range hostPeerUpdate.FwUpdate.AllowedNetworks { - if allowedNet.IP.To4() != nil { - aclRule.IPList = append(aclRule.IPList, allowedNet) - } else { - aclRule.IP6List = append(aclRule.IP6List, allowedNet) - } - } - hostPeerUpdate.FwUpdate.AclRules["allowed-network-rules"] = aclRule + hostPeerUpdate.FwUpdate.EgressInfo["allowed-network-rules"] = models.EgressInfo{ - EgressID: "allowed-network-rules", - EgressFwRules: map[string]models.AclRule{ - "allowed-network-rules": aclRule, - }, + EgressID: "allowed-network-rules", + EgressFwRules: make(map[string]models.AclRule), } + for _, aclRule := range hostPeerUpdate.FwUpdate.AllowedNetworks { + hostPeerUpdate.FwUpdate.AclRules[aclRule.ID] = aclRule + hostPeerUpdate.FwUpdate.EgressInfo["allowed-network-rules"].EgressFwRules[aclRule.ID] = aclRule + } + } }() @@ -189,14 +183,17 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N for _, nodeID := range host.Nodes { networkAllowAll := true nodeID := nodeID + if nodeID == uuid.Nil.String() { + continue + } node, err := GetNodeByID(nodeID) if err != nil { continue } - if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE { continue } + GetNodeEgressInfo(&node) hostPeerUpdate = SetDefaultGw(node, hostPeerUpdate) if !hostPeerUpdate.IsInternetGw { hostPeerUpdate.IsInternetGw = IsInternetGw(node) @@ -204,13 +201,22 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy) defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy) - if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) || (!checkIfAnyPolicyisUniDirectional(node) && !checkIfAnyActiveEgressPolicy(node)) { - if node.NetworkRange.IP != nil { - hostPeerUpdate.FwUpdate.AllowedNetworks = append(hostPeerUpdate.FwUpdate.AllowedNetworks, node.NetworkRange) + if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) || + (!checkIfAnyPolicyisUniDirectional(node) && !checkIfAnyActiveEgressPolicy(node)) || + checkIfNodeHasAccessToAllResources(&node) { + aclRule := models.AclRule{ + ID: fmt.Sprintf("%s-allowed-network-rules", node.ID.String()), + AllowedProtocol: models.ALL, + Direction: models.TrafficDirectionBi, + Allowed: true, + IPList: []net.IPNet{node.NetworkRange}, + IP6List: []net.IPNet{node.NetworkRange6}, } - if node.NetworkRange6.IP != nil { - hostPeerUpdate.FwUpdate.AllowedNetworks = append(hostPeerUpdate.FwUpdate.AllowedNetworks, node.NetworkRange6) + if !(defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) { + aclRule.Dst = []net.IPNet{node.NetworkRange} + aclRule.Dst6 = []net.IPNet{node.NetworkRange6} } + hostPeerUpdate.FwUpdate.AllowedNetworks = append(hostPeerUpdate.FwUpdate.AllowedNetworks, aclRule) } else { networkAllowAll = false hostPeerUpdate.FwUpdate.AllowAll = false @@ -247,8 +253,9 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N PersistentKeepaliveInterval: &peerHost.PersistentKeepalive, ReplaceAllowedIPs: true, } + AddEgressInfoToPeerByAccess(&node, &peer) _, isFailOverPeer := node.FailOverPeers[peer.ID.String()] - if peer.IsEgressGateway { + if peer.EgressDetails.IsEgressGateway { peerKey := peerHost.PublicKey.String() if isFailOverPeer && peer.FailedOverBy.String() != node.ID.String() { // get relay host @@ -435,7 +442,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N logger.Log(1, "error retrieving external clients:", err.Error()) } } - if node.IsEgressGateway && node.EgressGatewayRequest.NatEnabled == "yes" && len(node.EgressGatewayRequest.Ranges) > 0 { + if node.EgressDetails.IsEgressGateway && len(node.EgressDetails.EgressGatewayRequest.Ranges) > 0 { hostPeerUpdate.FwUpdate.IsEgressGw = true hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()] = models.EgressInfo{ EgressID: node.ID.String(), @@ -449,12 +456,12 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N IP: node.Address6.IP, Mask: getCIDRMaskFromAddr(node.Address6.IP.String()), }, - EgressGWCfg: node.EgressGatewayRequest, + EgressGWCfg: node.EgressDetails.EgressGatewayRequest, EgressFwRules: make(map[string]models.AclRule), } } - if node.IsEgressGateway { + if node.EgressDetails.IsEgressGateway { if !networkAllowAll { egressInfo := hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()] if egressInfo.EgressFwRules == nil { @@ -492,7 +499,6 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N }, } } - } // == post peer calculations == // indicate removal if no allowed IPs were calculated @@ -549,11 +555,11 @@ func GetPeerListenPort(host *models.Host) int { } func filterConflictingEgressRoutes(node, peer models.Node) []string { - egressIPs := slices.Clone(peer.EgressGatewayRanges) - if node.IsEgressGateway { + egressIPs := slices.Clone(peer.EgressDetails.EgressGatewayRanges) + if node.EgressDetails.IsEgressGateway { // filter conflicting addrs nodeEgressMap := make(map[string]struct{}) - for _, rangeI := range node.EgressGatewayRanges { + for _, rangeI := range node.EgressDetails.EgressGatewayRanges { nodeEgressMap[rangeI] = struct{}{} } for i := len(egressIPs) - 1; i >= 0; i-- { @@ -567,11 +573,11 @@ func filterConflictingEgressRoutes(node, peer models.Node) []string { } func filterConflictingEgressRoutesWithMetric(node, peer models.Node) []models.EgressRangeMetric { - egressIPs := slices.Clone(peer.EgressGatewayRequest.RangesWithMetric) - if node.IsEgressGateway { + egressIPs := slices.Clone(peer.EgressDetails.EgressGatewayRequest.RangesWithMetric) + if node.EgressDetails.IsEgressGateway { // filter conflicting addrs nodeEgressMap := make(map[string]struct{}) - for _, rangeI := range node.EgressGatewayRanges { + for _, rangeI := range node.EgressDetails.EgressGatewayRanges { nodeEgressMap[rangeI] = struct{}{} } for i := len(egressIPs) - 1; i >= 0; i-- { @@ -588,13 +594,13 @@ func filterConflictingEgressRoutesWithMetric(node, peer models.Node) []models.Eg func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet { var allowedips []net.IPNet allowedips = getNodeAllowedIPs(peer, node) - if peer.IsInternetGateway && node.InternetGwID == peer.ID.String() { + if peer.EgressDetails.IsInternetGateway && node.EgressDetails.InternetGwID == peer.ID.String() { allowedips = append(allowedips, GetAllowedIpForInetNodeClient(node, peer)...) return allowedips } if node.IsRelayed && node.RelayedBy == peer.ID.String() { allowedips = append(allowedips, GetAllowedIpsForRelayed(node, peer)...) - if peer.InternetGwID != "" { + if peer.EgressDetails.InternetGwID != "" { return allowedips } } @@ -623,11 +629,11 @@ func GetEgressIPs(peer *models.Node) []net.IPNet { // check for internet gateway internetGateway := false - if slices.Contains(peer.EgressGatewayRanges, "0.0.0.0/0") || slices.Contains(peer.EgressGatewayRanges, "::/0") { + if slices.Contains(peer.EgressDetails.EgressGatewayRanges, "0.0.0.0/0") || slices.Contains(peer.EgressDetails.EgressGatewayRanges, "::/0") { internetGateway = true } allowedips := []net.IPNet{} - for _, iprange := range peer.EgressGatewayRanges { // go through each cidr for egress gateway + for _, iprange := range peer.EgressDetails.EgressGatewayRanges { // go through each cidr for egress gateway _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr if err != nil { logger.Log(1, "could not parse gateway IP range. Not adding ", iprange) @@ -669,13 +675,13 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet { allowedips = append(allowedips, allowed) } // handle egress gateway peers - if peer.IsEgressGateway { + if peer.EgressDetails.IsEgressGateway { // hasGateway = true egressIPs := GetEgressIPs(peer) - if node.IsEgressGateway { + if node.EgressDetails.IsEgressGateway { // filter conflicting addrs nodeEgressMap := make(map[string]struct{}) - for _, rangeI := range node.EgressGatewayRanges { + for _, rangeI := range node.EgressDetails.EgressGatewayRanges { nodeEgressMap[rangeI] = struct{}{} } for i := len(egressIPs) - 1; i >= 0; i-- { diff --git a/logic/relay.go b/logic/relay.go index 94262cc5..66312019 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -114,13 +114,14 @@ func ValidateRelay(relay models.RelayRequest, update bool) error { if err != nil { return err } + GetNodeEgressInfo(&relayedNode) if relayedNode.IsIngressGateway { return errors.New("cannot relay an ingress gateway (" + relayedNodeID + ")") } - if relayedNode.IsInternetGateway { + if relayedNode.EgressDetails.IsInternetGateway { return errors.New("cannot relay an internet gateway (" + relayedNodeID + ")") } - if relayedNode.InternetGwID != "" && relayedNode.InternetGwID != relay.NodeID { + if relayedNode.EgressDetails.InternetGwID != "" && relayedNode.EgressDetails.InternetGwID != relay.NodeID { return errors.New("cannot relay an internet client (" + relayedNodeID + ")") } if relayedNode.IsFailOver { @@ -193,8 +194,9 @@ func RelayedAllowedIPs(peer, node *models.Node) []net.IPNet { if err != nil { continue } + GetNodeEgressInfo(&relayedNode) allowed := getRelayedAddresses(relayedNodeID) - if relayedNode.IsEgressGateway { + if relayedNode.EgressDetails.IsEgressGateway { allowed = append(allowed, GetEgressIPs(&relayedNode)...) } allowedIPs = append(allowedIPs, allowed...) @@ -208,7 +210,7 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe logger.Log(0, "RelayedByRelay called with invalid parameters") return } - if relay.InternetGwID != "" { + if relay.EgressDetails.InternetGwID != "" { return GetAllowedIpForInetNodeClient(relayed, relay) } peers, err := GetNetworkNodes(relay.Network) diff --git a/logic/tags.go b/logic/tags.go index c48bb6bd..655f48fd 100644 --- a/logic/tags.go +++ b/logic/tags.go @@ -290,7 +290,7 @@ func CreateDefaultTags(netID models.NetworkID) { TagName: models.GwTagName, Network: netID, CreatedBy: "auto", - CreatedAt: time.Now(), + CreatedAt: time.Now().UTC(), } _, err := GetTag(tag.ID) if err == nil { diff --git a/logic/wireguard.go b/logic/wireguard.go index 0828ff6f..778a9e5b 100644 --- a/logic/wireguard.go +++ b/logic/wireguard.go @@ -9,24 +9,10 @@ func IfaceDelta(currentNode *models.Node, newNode *models.Node) bool { // single comparison statements if newNode.Address.String() != currentNode.Address.String() || newNode.Address6.String() != currentNode.Address6.String() || - newNode.IsEgressGateway != currentNode.IsEgressGateway || - newNode.IsIngressGateway != currentNode.IsIngressGateway || newNode.IsRelay != currentNode.IsRelay || - newNode.DNSOn != currentNode.DNSOn || newNode.Connected != currentNode.Connected { return true } - // multi-comparison statements - if newNode.IsEgressGateway { - if len(currentNode.EgressGatewayRanges) != len(newNode.EgressGatewayRanges) { - return true - } - for _, address := range newNode.EgressGatewayRanges { - if !StringSliceContains(currentNode.EgressGatewayRanges, address) { - return true - } - } - } if newNode.IsRelay { if len(currentNode.RelayedNodes) != len(newNode.RelayedNodes) { return true diff --git a/migrate/migrate.go b/migrate/migrate.go index 1b1f8b65..484aa879 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -1,20 +1,24 @@ package migrate import ( + "context" "encoding/json" "fmt" "log" "time" "golang.org/x/exp/slog" + "gorm.io/datatypes" "github.com/google/uuid" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) @@ -30,6 +34,7 @@ func Run() { updateNodes() updateAcls() migrateToGws() + migrateToEgressV1() } func assignSuperAdmin() { @@ -500,6 +505,218 @@ func migrateToGws() { } } +func migrateToEgressV1() { + nodes, _ := logic.GetAllNodes() + user, err := logic.GetSuperAdmin() + if err != nil { + return + } + for _, node := range nodes { + if node.IsEgressGateway { + egressHost, err := logic.GetHost(node.HostID.String()) + if err != nil { + continue + } + for _, rangeI := range node.EgressGatewayRequest.Ranges { + e := schema.Egress{ + ID: uuid.New().String(), + Name: fmt.Sprintf("%s egress", egressHost.Name), + Description: "", + Network: node.Network, + Nodes: datatypes.JSONMap{ + node.ID.String(): 256, + }, + Tags: make(datatypes.JSONMap), + Range: rangeI, + Nat: node.EgressGatewayRequest.NatEnabled == "yes", + Status: true, + CreatedBy: user.UserName, + CreatedAt: time.Now().UTC(), + } + err = e.Create(db.WithContext(context.TODO())) + if err == nil { + node.IsEgressGateway = false + node.EgressGatewayRequest = models.EgressGatewayRequest{} + node.EgressGatewayNatEnabled = false + node.EgressGatewayRanges = []string{} + logic.UpsertNode(&node) + acl := models.Acl{ + ID: uuid.New().String(), + Name: "egress node policy", + MetaData: "", + Default: false, + ServiceType: models.Any, + NetworkID: models.NetworkID(node.Network), + Proto: models.ALL, + RuleType: models.DevicePolicy, + Src: []models.AclPolicyTag{ + + { + ID: models.NodeTagID, + Value: "*", + }, + }, + Dst: []models.AclPolicyTag{ + { + ID: models.EgressID, + Value: e.ID, + }, + }, + + AllowedDirection: models.TrafficDirectionUni, + Enabled: true, + CreatedBy: "auto", + CreatedAt: time.Now().UTC(), + } + logic.InsertAcl(acl) + acl = models.Acl{ + ID: uuid.New().String(), + Name: "egress node policy", + MetaData: "", + Default: false, + ServiceType: models.Any, + NetworkID: models.NetworkID(node.Network), + Proto: models.ALL, + RuleType: models.UserPolicy, + Src: []models.AclPolicyTag{ + + { + ID: models.UserGroupAclID, + Value: "*", + }, + }, + Dst: []models.AclPolicyTag{ + { + ID: models.EgressID, + Value: e.ID, + }, + }, + + AllowedDirection: models.TrafficDirectionUni, + Enabled: true, + CreatedBy: "auto", + CreatedAt: time.Now().UTC(), + } + logic.InsertAcl(acl) + } + + } + + } + + if node.IsInternetGateway { + inetHost, err := logic.GetHost(node.HostID.String()) + if err != nil { + continue + } + e := schema.Egress{ + ID: uuid.New().String(), + Name: fmt.Sprintf("%s inet gw", inetHost.Name), + Description: "add description", + Network: node.Network, + Nodes: datatypes.JSONMap{ + node.ID.String(): 256, + }, + Tags: make(datatypes.JSONMap), + Range: "", + IsInetGw: true, + Nat: node.EgressGatewayRequest.NatEnabled == "yes", + Status: true, + CreatedBy: user.UserName, + CreatedAt: time.Now().UTC(), + } + err = e.Create(db.WithContext(context.TODO())) + if err == nil { + node.IsEgressGateway = false + node.EgressGatewayRequest = models.EgressGatewayRequest{} + node.EgressGatewayNatEnabled = false + node.EgressGatewayRanges = []string{} + node.IsInternetGateway = false + src := []models.AclPolicyTag{} + for _, inetClientID := range node.InetNodeReq.InetNodeClientIDs { + _, err := logic.GetNodeByID(inetClientID) + if err == nil { + src = append(src, models.AclPolicyTag{ + ID: models.NodeID, + Value: inetClientID, + }) + } + } + acl := models.Acl{ + ID: uuid.New().String(), + Name: "exit node policy", + MetaData: "all traffic on source nodes will pass through the destination node in the policy", + Default: false, + ServiceType: models.Any, + NetworkID: models.NetworkID(node.Network), + Proto: models.ALL, + RuleType: models.DevicePolicy, + Src: src, + Dst: []models.AclPolicyTag{ + { + ID: models.EgressID, + Value: e.ID, + }, + }, + + AllowedDirection: models.TrafficDirectionBi, + Enabled: true, + CreatedBy: "auto", + CreatedAt: time.Now().UTC(), + } + logic.InsertAcl(acl) + + acl = models.Acl{ + ID: uuid.New().String(), + Name: "exit node policy", + MetaData: "all traffic on source nodes will pass through the destination node in the policy", + Default: false, + ServiceType: models.Any, + NetworkID: models.NetworkID(node.Network), + Proto: models.ALL, + RuleType: models.UserPolicy, + Src: []models.AclPolicyTag{ + { + ID: models.UserGroupAclID, + Value: fmt.Sprintf("%s-%s-grp", node.Network, models.NetworkAdmin), + }, + { + ID: models.UserGroupAclID, + Value: fmt.Sprintf("global-%s-grp", models.NetworkAdmin), + }, + { + ID: models.UserGroupAclID, + Value: fmt.Sprintf("%s-%s-grp", node.Network, models.NetworkUser), + }, + { + ID: models.UserGroupAclID, + Value: fmt.Sprintf("global-%s-grp", models.NetworkUser), + }, + }, + Dst: []models.AclPolicyTag{ + { + ID: models.EgressID, + Value: e.ID, + }, + }, + + AllowedDirection: models.TrafficDirectionBi, + Enabled: true, + CreatedBy: "auto", + CreatedAt: time.Now().UTC(), + } + logic.InsertAcl(acl) + node.InetNodeReq = models.InetNodeReq{} + logic.UpsertNode(&node) + } + } + if node.InternetGwID != "" { + node.InternetGwID = "" + logic.UpsertNode(&node) + } + } +} + func settings() { _, err := database.FetchRecords(database.SERVER_SETTINGS) if database.IsEmptyRecord(err) { diff --git a/models/accessToken.go b/models/accessToken.go new file mode 100644 index 00000000..cfa5c042 --- /dev/null +++ b/models/accessToken.go @@ -0,0 +1,60 @@ +package models + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" +) + +// accessTokenTableName - access tokens table +const accessTokenTableName = "user_access_tokens" + +// UserAccessToken - token used to access netmaker +type UserAccessToken struct { + ID string `gorm:"id,primary_key" json:"id"` + Name string `gorm:"name" json:"name"` + UserName string `gorm:"user_name" json:"user_name"` + ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"` + LastUsed time.Time `gorm:"last_used" json:"last_used"` + CreatedBy string `gorm:"created_by" json:"created_by"` + CreatedAt time.Time `gorm:"created_at" json:"created_at"` +} + +func (a *UserAccessToken) Table() string { + return accessTokenTableName +} + +func (a *UserAccessToken) Get() error { + return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error +} + +func (a *UserAccessToken) Update() error { + return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error +} + +func (a *UserAccessToken) Create() error { + return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error +} + +func (a *UserAccessToken) List() (ats []UserAccessToken, err error) { + err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error + return +} + +func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) { + db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats) + if ats == nil { + ats = []UserAccessToken{} + } + return +} + +func (a *UserAccessToken) Delete() error { + return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error +} + +func (a *UserAccessToken) DeleteAllUserTokens() error { + return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error + +} diff --git a/models/acl.go b/models/acl.go index 18d7a3d0..adef6c46 100644 --- a/models/acl.go +++ b/models/acl.go @@ -60,6 +60,7 @@ const ( NodeTagID AclGroupType = "tag" NodeID AclGroupType = "device" EgressRange AclGroupType = "egress-range" + EgressID AclGroupType = "egress-id" NetmakerIPAclID AclGroupType = "ip" NetmakerSubNetRangeAClID AclGroupType = "ipset" ) diff --git a/models/api_node.go b/models/api_node.go index 4cf0af3c..58934c31 100644 --- a/models/api_node.go +++ b/models/api_node.go @@ -79,21 +79,16 @@ func (a *ApiNode) ConvertToServerNode(currentNode *Node) *Node { convertedNode.PendingDelete = a.PendingDelete convertedNode.FailedOverBy = currentNode.FailedOverBy convertedNode.FailOverPeers = currentNode.FailOverPeers - convertedNode.IsEgressGateway = a.IsEgressGateway convertedNode.IsIngressGateway = a.IsIngressGateway - // prevents user from changing ranges, must delete and recreate - convertedNode.EgressGatewayRanges = currentNode.EgressGatewayRanges convertedNode.IngressGatewayRange = currentNode.IngressGatewayRange convertedNode.IngressGatewayRange6 = currentNode.IngressGatewayRange6 convertedNode.DNSOn = a.DNSOn convertedNode.IngressDNS = a.IngressDns convertedNode.IngressPersistentKeepalive = a.IngressPersistentKeepalive convertedNode.IngressMTU = a.IngressMTU - convertedNode.IsInternetGateway = a.IsInternetGateway - convertedNode.EgressGatewayRequest = currentNode.EgressGatewayRequest - convertedNode.EgressGatewayNatEnabled = currentNode.EgressGatewayNatEnabled - convertedNode.InternetGwID = currentNode.InternetGwID - convertedNode.InetNodeReq = currentNode.InetNodeReq + convertedNode.EgressDetails.IsInternetGateway = a.IsInternetGateway + convertedNode.EgressDetails.InternetGwID = currentNode.EgressDetails.InternetGwID + convertedNode.EgressDetails.InetNodeReq = currentNode.EgressDetails.InetNodeReq convertedNode.RelayedNodes = a.RelayedNodes convertedNode.DefaultACL = a.DefaultACL convertedNode.OwnerID = currentNode.OwnerID @@ -187,11 +182,7 @@ func (nm *Node) ConvertToAPINode() *ApiNode { apiNode.IsRelay = nm.IsRelay apiNode.RelayedBy = nm.RelayedBy apiNode.RelayedNodes = nm.RelayedNodes - apiNode.IsEgressGateway = nm.IsEgressGateway apiNode.IsIngressGateway = nm.IsIngressGateway - apiNode.EgressGatewayRanges = nm.EgressGatewayRanges - apiNode.EgressGatewayRangesWithMetric = nm.EgressGatewayRequest.RangesWithMetric - apiNode.EgressGatewayNatEnabled = nm.EgressGatewayNatEnabled apiNode.DNSOn = nm.DNSOn apiNode.IngressDns = nm.IngressDNS apiNode.IngressPersistentKeepalive = nm.IngressPersistentKeepalive @@ -200,9 +191,9 @@ func (nm *Node) ConvertToAPINode() *ApiNode { apiNode.Connected = nm.Connected apiNode.PendingDelete = nm.PendingDelete apiNode.DefaultACL = nm.DefaultACL - apiNode.IsInternetGateway = nm.IsInternetGateway - apiNode.InternetGwID = nm.InternetGwID - apiNode.InetNodeReq = nm.InetNodeReq + apiNode.IsInternetGateway = nm.EgressDetails.IsInternetGateway + apiNode.InternetGwID = nm.EgressDetails.InternetGwID + apiNode.InetNodeReq = nm.EgressDetails.InetNodeReq apiNode.IsFailOver = nm.IsFailOver apiNode.FailOverPeers = nm.FailOverPeers apiNode.FailedOverBy = nm.FailedOverBy diff --git a/models/egress.go b/models/egress.go new file mode 100644 index 00000000..02295c8c --- /dev/null +++ b/models/egress.go @@ -0,0 +1,14 @@ +package models + +type EgressReq struct { + ID string `json:"id"` + Name string `json:"name"` + Network string `json:"network"` + Description string `json:"description"` + Nodes map[string]int `json:"nodes"` + Tags []string `json:"tags"` + Range string `json:"range"` + Nat bool `json:"nat"` + Status bool `json:"status"` + IsInetGw bool `json:"is_internet_gateway"` +} diff --git a/models/mqtt.go b/models/mqtt.go index 2aca9259..25cafb28 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -107,7 +107,7 @@ type KeyUpdate struct { // FwUpdate - struct for firewall updates type FwUpdate struct { AllowAll bool `json:"allow_all"` - AllowedNetworks []net.IPNet `json:"networks"` + AllowedNetworks []AclRule `json:"networks"` IsEgressGw bool `json:"is_egress_gw"` IsIngressGw bool `json:"is_ingress_gw"` EgressInfo map[string]EgressInfo `json:"egress_info"` diff --git a/models/node.go b/models/node.go index 72dc9ea5..378fccab 100644 --- a/models/node.go +++ b/models/node.go @@ -109,7 +109,7 @@ type Node struct { DefaultACL string `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"` OwnerID string `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"` IsFailOver bool `json:"is_fail_over" yaml:"is_fail_over"` - FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` + FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"` IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` @@ -121,6 +121,16 @@ type Node struct { StaticNode ExtClient `json:"static_node"` Status NodeStatus `json:"node_status"` Mutex *sync.Mutex `json:"-"` + EgressDetails EgressDetails `json:"-"` +} +type EgressDetails struct { + EgressGatewayNatEnabled bool + EgressGatewayRequest EgressGatewayRequest + IsEgressGateway bool + EgressGatewayRanges []string + IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` + InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` + InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` } // LegacyNode - legacy struct for node model @@ -377,17 +387,17 @@ func (node *LegacyNode) SetIsStaticDefault() { // Node.SetLastModified - set last modified initial time func (node *Node) SetLastModified() { - node.LastModified = time.Now() + node.LastModified = time.Now().UTC() } // Node.SetLastCheckIn - set checkin time of node func (node *Node) SetLastCheckIn() { - node.LastCheckIn = time.Now() + node.LastCheckIn = time.Now().UTC() } // Node.SetLastPeerUpdate - sets last peer update time func (node *Node) SetLastPeerUpdate() { - node.LastPeerUpdate = time.Now() + node.LastPeerUpdate = time.Now().UTC() } // Node.SetExpirationDateTime - sets node expiry time @@ -442,15 +452,9 @@ func (newNode *Node) Fill( if newNode.Network == "" { newNode.Network = currentNode.Network } - if newNode.IsEgressGateway != currentNode.IsEgressGateway { - newNode.IsEgressGateway = currentNode.IsEgressGateway - } if newNode.IsIngressGateway != currentNode.IsIngressGateway { newNode.IsIngressGateway = currentNode.IsIngressGateway } - if newNode.EgressGatewayRanges == nil { - newNode.EgressGatewayRanges = currentNode.EgressGatewayRanges - } if newNode.IngressGatewayRange == "" { newNode.IngressGatewayRange = currentNode.IngressGatewayRange } @@ -567,7 +571,6 @@ func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) { } } node.Action = ln.Action - node.IsEgressGateway = parseBool(ln.IsEgressGateway) node.IsIngressGateway = parseBool(ln.IsIngressGateway) node.DNSOn = parseBool(ln.DNSOn) @@ -601,7 +604,6 @@ func (n *Node) Legacy(h *Host, s *ServerConfig, net *Network) *LegacyNode { //l.IsRelay = formatBool(n.IsRelay) //l.IsDocker = formatBool(n.IsDocker) //l.IsK8S = formatBool(n.IsK8S) - l.IsEgressGateway = formatBool(n.IsEgressGateway) l.IsIngressGateway = formatBool(n.IsIngressGateway) //l.EgressGatewayRanges = n.EgressGatewayRanges //l.EgressGatewayNatEnabled = n.EgressGatewayNatEnabled diff --git a/models/structs.go b/models/structs.go index b6ece02a..09f3cf52 100644 --- a/models/structs.go +++ b/models/structs.go @@ -156,6 +156,7 @@ type ExtPeersResponse struct { type EgressRangeMetric struct { Network string `json:"network"` RouteMetric uint32 `json:"route_metric"` // preffered range 1-999 + Nat bool `json:"nat"` } // EgressGatewayRequest - egress gateway request diff --git a/pro/controllers/failover.go b/pro/controllers/failover.go index 13a9df30..ef5f2a17 100644 --- a/pro/controllers/failover.go +++ b/pro/controllers/failover.go @@ -205,6 +205,8 @@ func failOverME(w http.ResponseWriter, r *http.Request) { ) return } + logic.GetNodeEgressInfo(&node) + logic.GetNodeEgressInfo(&peerNode) if peerNode.IsFailOver { logic.ReturnErrorResponse( w, @@ -245,7 +247,7 @@ func failOverME(w http.ResponseWriter, r *http.Request) { ) return } - if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() { + if node.EgressDetails.IsInternetGateway && peerNode.EgressDetails.InternetGwID == node.ID.String() { logic.ReturnErrorResponse( w, r, @@ -256,7 +258,7 @@ func failOverME(w http.ResponseWriter, r *http.Request) { ) return } - if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() { + if node.EgressDetails.InternetGwID != "" && node.EgressDetails.InternetGwID == peerNode.ID.String() { logic.ReturnErrorResponse( w, r, @@ -349,6 +351,8 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) { ) return } + logic.GetNodeEgressInfo(&node) + logic.GetNodeEgressInfo(&peerNode) if peerNode.IsFailOver { logic.ReturnErrorResponse( w, @@ -389,7 +393,18 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) { ) return } - if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() { + if node.EgressDetails.InternetGwID != "" || peerNode.EgressDetails.InternetGwID != "" { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError( + errors.New("node using a internet gw by the peer node"), + "badrequest", + ), + ) + return + } + if node.EgressDetails.IsInternetGateway && peerNode.EgressDetails.InternetGwID == node.ID.String() { logic.ReturnErrorResponse( w, r, @@ -400,7 +415,7 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) { ) return } - if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() { + if node.EgressDetails.InternetGwID != "" && node.EgressDetails.InternetGwID == peerNode.ID.String() { logic.ReturnErrorResponse( w, r, @@ -411,6 +426,17 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) { ) return } + if ok := logic.IsPeerAllowed(node, peerNode, true); !ok { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError( + errors.New("peers are not allowed to communicate"), + "badrequest", + ), + ) + return + } err = proLogic.CheckFailOverCtx(failOverNode, node, peerNode) if err != nil { diff --git a/pro/controllers/inet_gws.go b/pro/controllers/inet_gws.go index d1cd8fd8..b00b9d0c 100644 --- a/pro/controllers/inet_gws.go +++ b/pro/controllers/inet_gws.go @@ -44,7 +44,7 @@ func createInternetGw(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - if node.IsInternetGateway { + if node.EgressDetails.IsInternetGateway { logic.ReturnSuccessResponse(w, r, "node is already acting as internet gateway") return } @@ -132,7 +132,7 @@ func updateInternetGw(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - if !node.IsInternetGateway { + if !node.EgressDetails.IsInternetGateway { logic.ReturnErrorResponse( w, r, diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 9d34d3a8..91cf6579 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -942,7 +942,7 @@ func getUserRemoteAccessNetworkGateways(w http.ResponseWriter, r *http.Request) GwID: node.ID.String(), GWName: host.Name, Network: node.Network, - IsInternetGateway: node.IsInternetGateway, + IsInternetGateway: node.EgressDetails.IsInternetGateway, Metadata: node.Metadata, }) @@ -1069,7 +1069,7 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) { Network: node.Network, GwClient: userConf, Connected: true, - IsInternetGateway: node.IsInternetGateway, + IsInternetGateway: node.EgressDetails.IsInternetGateway, GwPeerPublicKey: host.PublicKey.String(), GwListenPort: logic.GetPeerListenPort(host), Metadata: node.Metadata, @@ -1161,7 +1161,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { Network: node.Network, GwClient: extClient, Connected: true, - IsInternetGateway: node.IsInternetGateway, + IsInternetGateway: node.EgressDetails.IsInternetGateway, GwPeerPublicKey: host.PublicKey.String(), GwListenPort: logic.GetPeerListenPort(host), Metadata: node.Metadata, @@ -1205,7 +1205,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { GwID: node.ID.String(), GWName: host.Name, Network: node.Network, - IsInternetGateway: node.IsInternetGateway, + IsInternetGateway: node.EgressDetails.IsInternetGateway, GwPeerPublicKey: host.PublicKey.String(), GwListenPort: logic.GetPeerListenPort(host), Metadata: node.Metadata, diff --git a/pro/initialize.go b/pro/initialize.go index 85ebeae5..fc57aa99 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -110,6 +110,7 @@ func InitPro() { logic.DeleteMetrics = proLogic.DeleteMetrics logic.GetTrialEndDate = getTrialEndDate logic.SetDefaultGw = proLogic.SetDefaultGw + logic.ValidateInetGwReq = proLogic.ValidateInetGwReq logic.SetDefaultGwForRelayedUpdate = proLogic.SetDefaultGwForRelayedUpdate logic.UnsetInternetGw = proLogic.UnsetInternetGw logic.SetInternetGw = proLogic.SetInternetGw diff --git a/pro/logic/failover.go b/pro/logic/failover.go index d4ac5ff6..7c21ba52 100644 --- a/pro/logic/failover.go +++ b/pro/logic/failover.go @@ -165,6 +165,7 @@ func GetFailOverPeerIps(peer, node *models.Node) []net.IPNet { for failOverpeerID := range node.FailOverPeers { failOverpeer, err := logic.GetNodeByID(failOverpeerID) if err == nil && failOverpeer.FailedOverBy == peer.ID { + logic.GetNodeEgressInfo(&failOverpeer) if failOverpeer.Address.IP != nil { allowed := net.IPNet{ IP: failOverpeer.Address.IP, @@ -179,7 +180,7 @@ func GetFailOverPeerIps(peer, node *models.Node) []net.IPNet { } allowedips = append(allowedips, allowed) } - if failOverpeer.IsEgressGateway { + if failOverpeer.EgressDetails.IsEgressGateway { allowedips = append(allowedips, logic.GetEgressIPs(&failOverpeer)...) } if failOverpeer.IsRelay { @@ -199,7 +200,7 @@ func GetFailOverPeerIps(peer, node *models.Node) []net.IPNet { } allowedips = append(allowedips, allowed) } - if rNode.IsEgressGateway { + if rNode.EgressDetails.IsEgressGateway { allowedips = append(allowedips, logic.GetEgressIPs(&rNode)...) } } diff --git a/pro/logic/nodes.go b/pro/logic/nodes.go index ca8ca94a..930cc72f 100644 --- a/pro/logic/nodes.go +++ b/pro/logic/nodes.go @@ -24,7 +24,7 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool if inetHost.FirewallInUse == models.FIREWALL_NONE { return errors.New("iptables or nftables needs to be installed") } - if inetNode.InternetGwID != "" { + if inetNode.EgressDetails.InternetGwID != "" { return fmt.Errorf("node %s is using a internet gateway already", inetHost.Name) } if inetNode.IsRelayed { @@ -36,22 +36,28 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool if err != nil { return err } + if clientNode.IsFailOver { + return errors.New("failover node cannot be set to use internet gateway") + } clientHost, err := logic.GetHost(clientNode.HostID.String()) if err != nil { return err } + if clientHost.IsDefault { + return errors.New("default host cannot be set to use internet gateway") + } if clientHost.OS != models.OS_Types.Linux && clientHost.OS != models.OS_Types.Windows { return errors.New("can only attach linux or windows machine to a internet gateway") } - if clientNode.IsInternetGateway { + if clientNode.EgressDetails.IsInternetGateway { return fmt.Errorf("node %s acting as internet gateway cannot use another internet gateway", clientHost.Name) } if update { - if clientNode.InternetGwID != "" && clientNode.InternetGwID != inetNode.ID.String() { + if clientNode.EgressDetails.InternetGwID != "" && clientNode.EgressDetails.InternetGwID != inetNode.ID.String() { return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name) } } else { - if clientNode.InternetGwID != "" { + if clientNode.EgressDetails.InternetGwID != "" { return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name) } } @@ -68,7 +74,7 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool if err != nil { continue } - if node.InternetGwID != "" && node.InternetGwID != inetNode.ID.String() { + if node.EgressDetails.InternetGwID != "" && node.EgressDetails.InternetGwID != inetNode.ID.String() { return errors.New("nodes on same host cannot use different internet gateway") } @@ -79,14 +85,14 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool // SetInternetGw - sets the node as internet gw based on flag bool func SetInternetGw(node *models.Node, req models.InetNodeReq) { - node.IsInternetGateway = true - node.InetNodeReq = req + node.EgressDetails.IsInternetGateway = true + node.EgressDetails.InetNodeReq = req for _, clientNodeID := range req.InetNodeClientIDs { clientNode, err := logic.GetNodeByID(clientNodeID) if err != nil { continue } - clientNode.InternetGwID = node.ID.String() + clientNode.EgressDetails.InternetGwID = node.ID.String() logic.UpsertNode(&clientNode) } @@ -99,19 +105,19 @@ func UnsetInternetGw(node *models.Node) { return } for _, clientNode := range nodes { - if node.ID.String() == clientNode.InternetGwID { - clientNode.InternetGwID = "" + if node.ID.String() == clientNode.EgressDetails.InternetGwID { + clientNode.EgressDetails.InternetGwID = "" logic.UpsertNode(&clientNode) } } - node.IsInternetGateway = false - node.InetNodeReq = models.InetNodeReq{} + node.EgressDetails.IsInternetGateway = false + node.EgressDetails.InetNodeReq = models.InetNodeReq{} } func SetDefaultGwForRelayedUpdate(relayed, relay models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate { - if relay.InternetGwID != "" { + if relay.EgressDetails.InternetGwID != "" { relayedHost, err := logic.GetHost(relayed.HostID.String()) if err != nil { return peerUpdate @@ -127,9 +133,9 @@ func SetDefaultGwForRelayedUpdate(relayed, relay models.Node, peerUpdate models. } func SetDefaultGw(node models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate { - if node.InternetGwID != "" { + if node.EgressDetails.InternetGwID != "" { - inetNode, err := logic.GetNodeByID(node.InternetGwID) + inetNode, err := logic.GetNodeByID(node.EgressDetails.InternetGwID) if err != nil { return peerUpdate } diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index ae4bbbb5..0be37520 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -7,7 +7,6 @@ import ( "time" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" @@ -659,30 +658,13 @@ func GetUserRAGNodesV1(user models.User) (gws map[string]models.Node) { func GetUserRAGNodes(user models.User) (gws map[string]models.Node) { gws = make(map[string]models.Node) - userGwAccessScope := GetUserNetworkRolesWithRemoteVPNAccess(user) - logger.Log(3, fmt.Sprintf("User Gw Access Scope: %+v", userGwAccessScope)) - _, allNetAccess := userGwAccessScope["*"] nodes, err := logic.GetAllNodes() if err != nil { return } for _, node := range nodes { - if node.IsIngressGateway && !node.PendingDelete { - if allNetAccess { - gws[node.ID.String()] = node - } else { - gwRsrcMap := userGwAccessScope[models.NetworkID(node.Network)] - scope, ok := gwRsrcMap[models.AllRemoteAccessGwRsrcID] - if !ok { - if scope, ok = gwRsrcMap[models.RsrcID(node.ID.String())]; !ok { - continue - } - } - if scope.VPNaccess { - gws[node.ID.String()] = node - } - - } + if ok, _ := logic.IsUserAllowedToCommunicate(user.UserName, node); ok { + gws[node.ID.String()] = node } } return diff --git a/schema/activity.go b/schema/activity.go new file mode 100644 index 00000000..35c0b55c --- /dev/null +++ b/schema/activity.go @@ -0,0 +1,4 @@ +package schema + +type Activity struct { +} diff --git a/schema/egress.go b/schema/egress.go new file mode 100644 index 00000000..f24d3813 --- /dev/null +++ b/schema/egress.go @@ -0,0 +1,70 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +const egressTable = "egresses" + +type Egress struct { + ID string `gorm:"primaryKey" json:"id"` + Name string `gorm:"name" json:"name"` + Network string `gorm:"network" json:"network"` + Description string `gorm:"description" json:"description"` + Nodes datatypes.JSONMap `gorm:"nodes" json:"nodes"` + Tags datatypes.JSONMap `gorm:"tags" json:"tags"` + Range string `gorm:"range" json:"range"` + Nat bool `gorm:"nat" json:"nat"` + IsInetGw bool `gorm:"is_inet_gw" json:"is_internet_gateway"` + Status bool `gorm:"status" json:"status"` + CreatedBy string `gorm:"created_by" json:"created_by"` + CreatedAt time.Time `gorm:"created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` +} + +func (e *Egress) Table() string { + return egressTable +} + +func (e *Egress) Get(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).First(&e).Where("id = ?", e.ID).Error +} + +func (e *Egress) Update(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(&e).Error +} + +func (e *Egress) UpdateNatStatus(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{ + "nat": e.Nat, + }).Error +} + +func (e *Egress) UpdateINetGwStatus(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{ + "is_inet_gw": e.IsInetGw, + }).Error +} + +func (e *Egress) UpdateEgressStatus(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{ + "status": e.Status, + }).Error +} + +func (e *Egress) Create(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Create(&e).Error +} + +func (e *Egress) ListByNetwork(ctx context.Context) (egs []Egress, err error) { + err = db.FromContext(ctx).Table(e.Table()).Where("network = ?", e.Network).Find(&egs).Error + return +} + +func (e *Egress) Delete(ctx context.Context) error { + return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error +} diff --git a/schema/models.go b/schema/models.go index 5c27a7b2..1ffbd0e3 100644 --- a/schema/models.go +++ b/schema/models.go @@ -4,6 +4,7 @@ package schema func ListModels() []interface{} { return []interface{}{ &Job{}, + &Egress{}, &UserAccessToken{}, } }