From 059b86bc93df685903518fc12acf911c63485a27 Mon Sep 17 00:00:00 2001 From: Vishal Dalwadi Date: Mon, 2 Jun 2025 22:56:38 +0530 Subject: [PATCH] feat(go): restore nodeacls; --- cli/cmd/acl/allow.go | 10 +- cli/cmd/acl/deny.go | 10 +- cli/cmd/acl/list.go | 12 +- cli/functions/acl.go | 10 +- controllers/network.go | 108 ++++++---------- controllers/node_test.go | 55 ++++++++- converters/node_acl.go | 38 ------ logic/acls/common.go | 199 ++++++++++++++++++++++++++++++ logic/acls/nodeacls/modify.go | 102 +++++++++++++++ logic/acls/nodeacls/retrieve.go | 76 ++++++++++++ logic/acls/nodeacls/types.go | 14 +++ logic/{nodeacls => acls}/types.go | 2 +- logic/clients.go | 37 ++---- logic/extpeers.go | 27 ++-- logic/networks.go | 7 +- logic/nodeacls/node_acls.go | 168 ------------------------- logic/nodes.go | 11 +- logic/peers.go | 6 +- logic/relay.go | 4 +- migrate/migrate.go | 89 ++++++------- pro/logic/ext_acls.go | 57 ++++----- schema/models.go | 1 - schema/network_acl.go | 49 -------- serverctl/serverctl.go | 13 +- 24 files changed, 616 insertions(+), 489 deletions(-) delete mode 100644 converters/node_acl.go create mode 100644 logic/acls/common.go create mode 100644 logic/acls/nodeacls/modify.go create mode 100644 logic/acls/nodeacls/retrieve.go create mode 100644 logic/acls/nodeacls/types.go rename logic/{nodeacls => acls}/types.go (97%) delete mode 100644 logic/nodeacls/node_acls.go delete mode 100644 schema/network_acl.go diff --git a/cli/cmd/acl/allow.go b/cli/cmd/acl/allow.go index d31bb508..14bef7e4 100644 --- a/cli/cmd/acl/allow.go +++ b/cli/cmd/acl/allow.go @@ -2,10 +2,10 @@ package acl import ( "fmt" - "github.com/gravitl/netmaker/logic/nodeacls" "log" "github.com/gravitl/netmaker/cli/functions" + "github.com/gravitl/netmaker/logic/acls" "github.com/spf13/cobra" ) @@ -31,16 +31,16 @@ var aclAllowCmd = &cobra.Command{ payload := *res - if _, ok := payload[nodeacls.AclID(fromNodeID)]; !ok { + if _, ok := payload[acls.AclID(fromNodeID)]; !ok { log.Fatalf("Node %s does not exist", fromNodeID) } - if _, ok := payload[nodeacls.AclID(toNodeID)]; !ok { + if _, ok := payload[acls.AclID(toNodeID)]; !ok { log.Fatalf("Node %s does not exist", toNodeID) } // update acls - payload[nodeacls.AclID(fromNodeID)][nodeacls.AclID(toNodeID)] = nodeacls.Allowed - payload[nodeacls.AclID(toNodeID)][nodeacls.AclID(fromNodeID)] = nodeacls.Allowed + payload[acls.AclID(fromNodeID)][acls.AclID(toNodeID)] = acls.Allowed + payload[acls.AclID(toNodeID)][acls.AclID(fromNodeID)] = acls.Allowed functions.UpdateACL(network, &payload) fmt.Println("Success") diff --git a/cli/cmd/acl/deny.go b/cli/cmd/acl/deny.go index 9ff83a89..587e42dc 100644 --- a/cli/cmd/acl/deny.go +++ b/cli/cmd/acl/deny.go @@ -2,10 +2,10 @@ package acl import ( "fmt" - "github.com/gravitl/netmaker/logic/nodeacls" "log" "github.com/gravitl/netmaker/cli/functions" + "github.com/gravitl/netmaker/logic/acls" "github.com/spf13/cobra" ) @@ -31,16 +31,16 @@ var aclDenyCmd = &cobra.Command{ payload := *res - if _, ok := payload[nodeacls.AclID(fromNodeID)]; !ok { + if _, ok := payload[acls.AclID(fromNodeID)]; !ok { log.Fatalf("Node [%s] does not exist", fromNodeID) } - if _, ok := payload[nodeacls.AclID(toNodeID)]; !ok { + if _, ok := payload[acls.AclID(toNodeID)]; !ok { log.Fatalf("Node [%s] does not exist", toNodeID) } // update acls - payload[nodeacls.AclID(fromNodeID)][nodeacls.AclID(toNodeID)] = nodeacls.NotAllowed - payload[nodeacls.AclID(toNodeID)][nodeacls.AclID(fromNodeID)] = nodeacls.NotAllowed + payload[acls.AclID(fromNodeID)][acls.AclID(toNodeID)] = acls.NotAllowed + payload[acls.AclID(toNodeID)][acls.AclID(fromNodeID)] = acls.NotAllowed functions.UpdateACL(network, &payload) fmt.Println("Success") diff --git a/cli/cmd/acl/list.go b/cli/cmd/acl/list.go index 65f588fa..41277444 100644 --- a/cli/cmd/acl/list.go +++ b/cli/cmd/acl/list.go @@ -1,11 +1,11 @@ package acl import ( - "github.com/gravitl/netmaker/logic/nodeacls" "os" "github.com/gravitl/netmaker/cli/cmd/commons" "github.com/gravitl/netmaker/cli/functions" + "github.com/gravitl/netmaker/logic/acls" "github.com/guumaster/tablewriter" "github.com/spf13/cobra" ) @@ -16,7 +16,7 @@ var aclListCmd = &cobra.Command{ Short: "List all ACLs associated with a network", Long: `List all ACLs associated with a network`, Run: func(cmd *cobra.Command, args []string) { - aclSource := (map[nodeacls.AclID]nodeacls.ACL)(*functions.GetACL(args[0])) + aclSource := (map[acls.AclID]acls.ACL)(*functions.GetACL(args[0])) switch commons.OutputFormat { case commons.JsonOutput: functions.PrettyPrint(aclSource) @@ -24,14 +24,14 @@ var aclListCmd = &cobra.Command{ table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"From", "To", "Status"}) for id, acl := range aclSource { - for k, v := range (map[nodeacls.AclID]byte)(acl) { + for k, v := range (map[acls.AclID]byte)(acl) { row := []string{string(id), string(k)} switch v { - case nodeacls.NotAllowed: + case acls.NotAllowed: row = append(row, "Not Allowed") - case nodeacls.NotPresent: + case acls.NotPresent: row = append(row, "Not Present") - case nodeacls.Allowed: + case acls.Allowed: row = append(row, "Allowed") } table.Append(row) diff --git a/cli/functions/acl.go b/cli/functions/acl.go index 0e3a3ce4..2fe4cf28 100644 --- a/cli/functions/acl.go +++ b/cli/functions/acl.go @@ -2,16 +2,16 @@ package functions import ( "fmt" - "github.com/gravitl/netmaker/logic/nodeacls" + "github.com/gravitl/netmaker/logic/acls" "net/http" ) // GetACL - fetch all ACLs associated with a network -func GetACL(networkName string) *nodeacls.ACLContainer { - return request[nodeacls.ACLContainer](http.MethodGet, fmt.Sprintf("/api/networks/%s/acls", networkName), nil) +func GetACL(networkName string) *acls.ACLContainer { + return request[acls.ACLContainer](http.MethodGet, fmt.Sprintf("/api/networks/%s/acls", networkName), nil) } // UpdateACL - update an ACL -func UpdateACL(networkName string, payload *nodeacls.ACLContainer) *nodeacls.ACLContainer { - return request[nodeacls.ACLContainer](http.MethodPut, fmt.Sprintf("/api/networks/%s/acls/v2", networkName), payload) +func UpdateACL(networkName string, payload *acls.ACLContainer) *acls.ACLContainer { + return request[acls.ACLContainer](http.MethodPut, fmt.Sprintf("/api/networks/%s/acls/v2", networkName), payload) } diff --git a/controllers/network.go b/controllers/network.go index 2a8c24bf..f748b5c5 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -4,10 +4,8 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gravitl/netmaker/converters" - "github.com/gravitl/netmaker/logic/nodeacls" + "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/schema" - "gorm.io/gorm" "net" "net/http" "slices" @@ -170,7 +168,7 @@ func getNetwork(w http.ResponseWriter, r *http.Request) { func updateNetworkACL(w http.ResponseWriter, r *http.Request) { networkID := mux.Vars(r)["networkname"] - var networkACLUpdateRequest nodeacls.ACLContainer + var networkACLUpdateRequest acls.ACLContainer err := json.NewDecoder(r.Body).Decode(&networkACLUpdateRequest) if err != nil { logger.Log(0, fmt.Sprintf("failed to decode network (%s) acl update request: %s", networkID, err.Error())) @@ -178,43 +176,32 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) { return } - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - err = _networkACL.Get(r.Context()) + var networkACL acls.ACLContainer + networkACL, err = networkACL.Get(acls.ContainerID(networkID)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acls: %s", networkID, err.Error())) - - if errors.Is(err, gorm.ErrRecordNotFound) { - err = fmt.Errorf("network (%s) acls not found", networkID) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "notfound")) - } else { - err = fmt.Errorf("failed to get network (%s) acls: %s", networkID, err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - } - - return - } - - _networkACLUpdates := converters.ToSchemaNetworkACL(networkID, networkACLUpdateRequest) - err = _networkACLUpdates.Update(r.Context()) - if err != nil { - logger.Log(0, fmt.Sprintf("failed to update network (%s) acls: %s", networkID, err.Error())) + logger.Log(0, r.Header.Get("user"), + fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", networkID, err)) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logger.Log(1, r.Header.Get("user"), "updated acls for network", networkID) + newNetACL, err := networkACL.Save(acls.ContainerID(networkID)) + if err != nil { + logger.Log(0, r.Header.Get("user"), + fmt.Sprintf("failed to update ACLs for network [%s]: %v", networkID, err)) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + logger.Log(1, r.Header.Get("user"), "updated ACLs for network", networkID) // send peer updates go func() { - err = mq.PublishPeerUpdate(false) - if err != nil { - logger.Log(0, fmt.Sprintf("failed to publish peer update after network (%s) acl updates: %s", networkID, err.Error())) + if err = mq.PublishPeerUpdate(false); err != nil { + logger.Log(0, "failed to publish peer update after ACL update on network:", networkID) } }() - logic.ReturnSuccessJsonResponse(w, r, networkACLUpdateRequest) + logic.ReturnSuccessJsonResponse(w, r, newNetACL) } // @Summary Update a network ACL (Access Control List) @@ -230,7 +217,7 @@ func updateNetworkACL(w http.ResponseWriter, r *http.Request) { func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { networkID := mux.Vars(r)["networkname"] - var networkACLUpdateRequest nodeacls.ACLContainer + var networkACLUpdateRequest acls.ACLContainer err := json.NewDecoder(r.Body).Decode(&networkACLUpdateRequest) if err != nil { logger.Log(0, fmt.Sprintf("failed to decode network (%s) acl update request: %s", networkID, err.Error())) @@ -238,21 +225,12 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { return } - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - err = _networkACL.Get(r.Context()) + var networkACL acls.ACLContainer + networkACL, err = networkACL.Get(acls.ContainerID(networkID)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acls: %s", networkID, err.Error())) - - if errors.Is(err, gorm.ErrRecordNotFound) { - err = fmt.Errorf("network (%s) acls not found", networkID) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "notfound")) - } else { - err = fmt.Errorf("failed to get network (%s) acls: %s", networkID, err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - } - + logger.Log(0, r.Header.Get("user"), + fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", networkID, err)) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -302,7 +280,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { if client.DeniedACLs == nil { client.DeniedACLs = make(map[string]struct{}) } - if acl[nodeacls.AclID(clientID)] == nodeacls.NotAllowed { + if acl[acls.AclID(clientID)] == acls.NotAllowed { client.DeniedACLs[nodeID] = struct{}{} } else { delete(client.DeniedACLs, nodeID) @@ -335,7 +313,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { } } else { nodeId2 := string(id2) - if extClientsMap[clientId].IngressGatewayID == nodeId2 && acl[nodeacls.AclID(nodeId2)] == nodeacls.NotAllowed { + if extClientsMap[clientId].IngressGatewayID == nodeId2 && acl[acls.AclID(nodeId2)] == acls.NotAllowed { assocClientsToDisconnectPerHost[nodesMap[nodeId2].HostID] = append(assocClientsToDisconnectPerHost[nodesMap[nodeId2].HostID], extClientsMap[clientId]) } } @@ -374,13 +352,14 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { } } - _networkACLUpdates := converters.ToSchemaNetworkACL(networkID, networkACLUpdateRequest) - err = _networkACLUpdates.Update(r.Context()) + _, err = networkACLUpdateRequest.Save(acls.ContainerID(networkID)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to update network (%s) acls: %s", networkID, err.Error())) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + logger.Log(0, r.Header.Get("user"), + fmt.Sprintf("failed to update ACLs for network [%s]: %v", networkID, err)) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + logger.Log(1, r.Header.Get("user"), "updated ACLs for network", networkID) logger.Log(1, r.Header.Get("user"), "updated acls for network", networkID) @@ -431,27 +410,22 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) { func getNetworkACL(w http.ResponseWriter, r *http.Request) { networkID := mux.Vars(r)["networkname"] - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - err := _networkACL.Get(r.Context()) + var networkACL acls.ACLContainer + networkACL, err := networkACL.Get(acls.ContainerID(networkID)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acls: %s", networkID, err.Error())) - - if errors.Is(err, gorm.ErrRecordNotFound) { - err = fmt.Errorf("network (%s) acls not found", networkID) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "notfound")) - } else { - err = fmt.Errorf("failed to get network (%s) acls: %s", networkID, err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + if database.IsEmptyRecord(err) { + networkACL = acls.ACLContainer{} + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(networkACL) + return } - + logger.Log(0, r.Header.Get("user"), + fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", networkID, err)) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logger.Log(2, r.Header.Get("user"), "fetched acls for network", networkID) - - logic.ReturnSuccessJsonResponse(w, r, converters.ToACLContainer(*_networkACL)) + logic.ReturnSuccessJsonResponse(w, r, networkACL) } // @Summary Get a network Egress routes diff --git a/controllers/node_test.go b/controllers/node_test.go index ce207972..8996b2f5 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -1,13 +1,14 @@ package controller import ( - "github.com/gravitl/netmaker/logic/nodeacls" "net" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -54,7 +55,42 @@ func TestNodeACLs(t *testing.T) { node2 := createNodeWithParams("", "10.0.0.100/32") logic.AssociateNodeToHost(node1, &linuxHost) logic.AssociateNodeToHost(node2, &linuxHost) - + t.Run("acls not present", func(t *testing.T) { + currentACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(node1.Network)) + assert.Nil(t, err) + assert.NotNil(t, currentACL) + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String())) + assert.Nil(t, err) + assert.NotNil(t, node1ACL) + assert.Equal(t, acls.Allowed, node1ACL[acls.AclID(node2.ID.String())]) + }) + t.Run("node acls exists after creates", func(t *testing.T) { + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String())) + assert.Nil(t, err) + assert.NotNil(t, node1ACL) + node2ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node2.Network), nodeacls.NodeID(node2.ID.String())) + assert.Nil(t, err) + assert.NotNil(t, node2ACL) + assert.Equal(t, acls.Allowed, node2ACL[acls.AclID(node1.ID.String())]) + }) + t.Run("node acls correct after fetch", func(t *testing.T) { + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String())) + assert.Nil(t, err) + assert.Equal(t, acls.Allowed, node1ACL[acls.AclID(node2.ID.String())]) + }) + t.Run("node acls correct after modify", func(t *testing.T) { + node1ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String())) + assert.Nil(t, err) + assert.NotNil(t, node1ACL) + node2ACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(node2.Network), nodeacls.NodeID(node2.ID.String())) + assert.Nil(t, err) + assert.NotNil(t, node2ACL) + currentACL, err := nodeacls.DisallowNodes(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String()), nodeacls.NodeID(node2.ID.String())) + assert.Nil(t, err) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node2.ID.String())]) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node2.ID.String())][acls.AclID(node1.ID.String())]) + currentACL.Save(acls.ContainerID(node1.Network)) + }) t.Run("node acls correct after add new node not allowed", func(t *testing.T) { node3 := createNodeWithParams("", "10.0.0.100/32") createNodeHosts() @@ -65,10 +101,23 @@ func TestNodeACLs(t *testing.T) { assert.Nil(t, e) err := logic.AssociateNodeToHost(node3, &linuxHost) assert.Nil(t, err) + currentACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(node3.Network)) + assert.Nil(t, err) + assert.NotNil(t, currentACL) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node3.ID.String())]) + nodeACL, err := nodeacls.CreateNodeACL(nodeacls.NetworkID(node3.Network), nodeacls.NodeID(node3.ID.String()), acls.NotAllowed) + assert.Nil(t, err) + nodeACL.Save(acls.ContainerID(node3.Network), acls.AclID(node3.ID.String())) + currentACL, err = nodeacls.FetchAllACLs(nodeacls.NetworkID(node3.Network)) + assert.Nil(t, err) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node1.ID.String())][acls.AclID(node3.ID.String())]) + assert.Equal(t, acls.NotAllowed, currentACL[acls.AclID(node2.ID.String())][acls.AclID(node3.ID.String())]) }) t.Run("node acls removed", func(t *testing.T) { - err := nodeacls.RemoveNodeACL(node1.Network, node1.ID.String()) + retNetworkACL, err := nodeacls.RemoveNodeACL(nodeacls.NetworkID(node1.Network), nodeacls.NodeID(node1.ID.String())) assert.Nil(t, err) + assert.NotNil(t, retNetworkACL) + assert.Equal(t, acls.NotPresent, retNetworkACL[acls.AclID(node2.ID.String())][acls.AclID(node1.ID.String())]) }) deleteAllNodes() } diff --git a/converters/node_acl.go b/converters/node_acl.go deleted file mode 100644 index 935d1bf8..00000000 --- a/converters/node_acl.go +++ /dev/null @@ -1,38 +0,0 @@ -package converters - -import ( - "github.com/gravitl/netmaker/logic/nodeacls" - "github.com/gravitl/netmaker/schema" - "gorm.io/datatypes" -) - -func ToSchemaNetworkACL(networkID string, aclContainer nodeacls.ACLContainer) schema.NetworkACL { - _networkACL := schema.NetworkACL{ - ID: networkID, - Access: datatypes.JSONType[map[string]map[string]byte]{}, - } - - for nodeID := range aclContainer { - _networkACL.Access.Data()[string(nodeID)] = make(map[string]byte) - - for peerID := range aclContainer[nodeID] { - _networkACL.Access.Data()[string(nodeID)][string(peerID)] = aclContainer[nodeID][peerID] - } - } - - return _networkACL -} - -func ToACLContainer(_networkACL schema.NetworkACL) nodeacls.ACLContainer { - var aclContainer = nodeacls.ACLContainer{} - - for nodeID := range _networkACL.Access.Data() { - aclContainer[nodeacls.AclID(nodeID)] = make(nodeacls.ACL) - - for peerID := range _networkACL.Access.Data()[nodeID] { - aclContainer[nodeacls.AclID(nodeID)][nodeacls.AclID(peerID)] = _networkACL.Access.Data()[nodeID][peerID] - } - } - - return aclContainer -} diff --git a/logic/acls/common.go b/logic/acls/common.go new file mode 100644 index 00000000..bb35123a --- /dev/null +++ b/logic/acls/common.go @@ -0,0 +1,199 @@ +package acls + +import ( + "encoding/json" + "maps" + "sync" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/exp/slog" +) + +var ( + aclCacheMutex = &sync.RWMutex{} + aclCacheMap = make(map[ContainerID]ACLContainer) + AclMutex = &sync.RWMutex{} +) + +func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) { + aclCacheMutex.RLock() + aclCont, ok = aclCacheMap[containerID] + aclCacheMutex.RUnlock() + return +} + +func storeAclContainerInCache(containerID ContainerID, aclContainer ACLContainer) { + aclCacheMutex.Lock() + aclCacheMap[containerID] = aclContainer + aclCacheMutex.Unlock() +} + +func DeleteAclFromCache(containerID ContainerID) { + aclCacheMutex.Lock() + delete(aclCacheMap, containerID) + aclCacheMutex.Unlock() +} + +// == type functions == + +// ACL.Allow - allows access by ID in memory +func (acl ACL) Allow(ID AclID) { + AclMutex.Lock() + defer AclMutex.Unlock() + acl[ID] = Allowed +} + +// ACL.DisallowNode - disallows access by ID in memory +func (acl ACL) Disallow(ID AclID) { + AclMutex.Lock() + defer AclMutex.Unlock() + acl[ID] = NotAllowed +} + +// ACL.Remove - removes a node from a ACL in memory +func (acl ACL) Remove(ID AclID) { + AclMutex.Lock() + defer AclMutex.Unlock() + delete(acl, ID) +} + +// ACL.Update - updates a ACL in DB +func (acl ACL) Save(containerID ContainerID, ID AclID) (ACL, error) { + return upsertACL(containerID, ID, acl) +} + +// ACL.IsAllowed - sees if ID is allowed in referring ACL +func (acl ACL) IsAllowed(ID AclID) (allowed bool) { + AclMutex.Lock() + allowed = acl[ID] == Allowed + AclMutex.Unlock() + return +} + +// ACLContainer.UpdateACL - saves the state of a ACL in the ACLContainer in memory +func (aclContainer ACLContainer) UpdateACL(ID AclID, acl ACL) ACLContainer { + AclMutex.Lock() + defer AclMutex.Unlock() + aclContainer[ID] = acl + return aclContainer +} + +// ACLContainer.RemoveACL - removes the state of a ACL in the ACLContainer in memory +func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer { + AclMutex.Lock() + defer AclMutex.Unlock() + delete(aclContainer, ID) + return aclContainer +} + +// ACLContainer.ChangeAccess - changes the relationship between two nodes in memory +func (networkACL ACLContainer) ChangeAccess(ID1, ID2 AclID, value byte) { + AclMutex.Lock() + defer AclMutex.Unlock() + if _, ok := networkACL[ID1]; !ok { + slog.Error("ACL missing for ", "id", ID1) + return + } + if _, ok := networkACL[ID2]; !ok { + slog.Error("ACL missing for ", "id", ID2) + return + } + if _, ok := networkACL[ID1][ID2]; !ok { + slog.Error("ACL missing for ", "id1", ID1, "id2", ID2) + return + } + if _, ok := networkACL[ID2][ID1]; !ok { + slog.Error("ACL missing for ", "id2", ID2, "id1", ID1) + return + } + networkACL[ID1][ID2] = value + networkACL[ID2][ID1] = value +} + +// ACLContainer.Save - saves the state of a ACLContainer to the db +func (aclContainer ACLContainer) Save(containerID ContainerID) (ACLContainer, error) { + return upsertACLContainer(containerID, aclContainer) +} + +// ACLContainer.New - saves the state of a ACLContainer to the db +func (aclContainer ACLContainer) New(containerID ContainerID) (ACLContainer, error) { + return upsertACLContainer(containerID, nil) +} + +// ACLContainer.Get - saves the state of a ACLContainer to the db +func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, error) { + return fetchACLContainer(containerID) +} + +// == private == + +// fetchACLContainer - fetches all current rules in given ACL container +func fetchACLContainer(containerID ContainerID) (ACLContainer, error) { + AclMutex.RLock() + defer AclMutex.RUnlock() + if servercfg.CacheEnabled() { + if aclContainer, ok := fetchAclContainerFromCache(containerID); ok { + return maps.Clone(aclContainer), nil + } + } + aclJson, err := fetchACLContainerJson(ContainerID(containerID)) + if err != nil { + return nil, err + } + var currentNetworkACL ACLContainer + if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { + return nil, err + } + if servercfg.CacheEnabled() { + storeAclContainerInCache(containerID, currentNetworkACL) + } + return maps.Clone(currentNetworkACL), nil +} + +// fetchACLContainerJson - fetch the current ACL of given container except in json string +func fetchACLContainerJson(containerID ContainerID) (ACLJson, error) { + currentACLs, err := database.FetchRecord(database.NODE_ACLS_TABLE_NAME, string(containerID)) + if err != nil { + return ACLJson(""), err + } + return ACLJson(currentACLs), nil +} + +// upsertACL - applies a ACL to the db, overwrites or creates +func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) { + currentNetACL, err := fetchACLContainer(containerID) + if err != nil { + return acl, err + } + currentNetACL[ID] = acl + _, err = upsertACLContainer(containerID, currentNetACL) + return acl, err +} + +// upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the container ID +// if nil, create it +func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACLContainer, error) { + AclMutex.Lock() + defer AclMutex.Unlock() + if aclContainer == nil { + aclContainer = make(ACLContainer) + } + + err := database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME) + if err != nil { + return aclContainer, err + } + if servercfg.CacheEnabled() { + storeAclContainerInCache(containerID, aclContainer) + } + return aclContainer, nil +} + +func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson { + data, err := json.Marshal(networkACL) + if err != nil { + return "" + } + return ACLJson(data) +} diff --git a/logic/acls/nodeacls/modify.go b/logic/acls/nodeacls/modify.go new file mode 100644 index 00000000..1c2de672 --- /dev/null +++ b/logic/acls/nodeacls/modify.go @@ -0,0 +1,102 @@ +package nodeacls + +import ( + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/servercfg" +) + +// CreateNodeACL - inserts or updates a node ACL on given network and adds to state +func CreateNodeACL(networkID NetworkID, nodeID NodeID, defaultVal byte) (acls.ACL, error) { + if defaultVal != acls.NotAllowed && defaultVal != acls.Allowed { + defaultVal = acls.NotAllowed + } + var currentNetworkACL, err = FetchAllACLs(networkID) + if err != nil { + if database.IsEmptyRecord(err) { + currentNetworkACL, err = currentNetworkACL.New(acls.ContainerID(networkID)) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + acls.AclMutex.Lock() + var newNodeACL = make(acls.ACL) + for existingNodeID := range currentNetworkACL { + if currentNetworkACL[existingNodeID] == nil { + currentNetworkACL[existingNodeID] = make(acls.ACL) + } + currentNetworkACL[existingNodeID][acls.AclID(nodeID)] = defaultVal // set the old nodes to default value for new node + newNodeACL[existingNodeID] = defaultVal // set the old nodes in new node ACL to default value + } + currentNetworkACL[acls.AclID(nodeID)] = newNodeACL // append the new node's ACL + acls.AclMutex.Unlock() + retNetworkACL, err := currentNetworkACL.Save(acls.ContainerID(networkID)) // insert into db + if err != nil { + return nil, err + } + return retNetworkACL[acls.AclID(nodeID)], nil +} + +// AllowNode - allow access between two nodes in memory +func AllowNodes(networkID NetworkID, node1, node2 NodeID) (acls.ACLContainer, error) { + container, err := FetchAllACLs(networkID) + if err != nil { + return nil, err + } + container[acls.AclID(node1)].Allow(acls.AclID(node2)) + container[acls.AclID(node2)].Allow(acls.AclID(node1)) + return container, nil +} + +// DisallowNodes - deny access between two nodes +func DisallowNodes(networkID NetworkID, node1, node2 NodeID) (acls.ACLContainer, error) { + container, err := FetchAllACLs(networkID) + if err != nil { + return nil, err + } + container[acls.AclID(node1)].Disallow(acls.AclID(node2)) + container[acls.AclID(node2)].Disallow(acls.AclID(node1)) + return container, nil +} + +// UpdateNodeACL - updates a node's ACL in state +func UpdateNodeACL(networkID NetworkID, nodeID NodeID, acl acls.ACL) (acls.ACL, error) { + var currentNetworkACL, err = FetchAllACLs(networkID) + if err != nil { + return nil, err + } + acls.AclMutex.Lock() + currentNetworkACL[acls.AclID(nodeID)] = acl + acls.AclMutex.Unlock() + return currentNetworkACL[acls.AclID(nodeID)].Save(acls.ContainerID(networkID), acls.AclID(nodeID)) +} + +// RemoveNodeACL - removes a specific Node's ACL, returns the NetworkACL and error +func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error) { + var currentNetworkACL, err = FetchAllACLs(networkID) + if err != nil { + return nil, err + } + for currentNodeID := range currentNetworkACL { + if NodeID(currentNodeID) != nodeID { + currentNetworkACL[currentNodeID].Remove(acls.AclID(nodeID)) + } + } + delete(currentNetworkACL, acls.AclID(nodeID)) + return currentNetworkACL.Save(acls.ContainerID(networkID)) +} + +// DeleteACLContainer - removes an ACLContainer state from db +func DeleteACLContainer(network NetworkID) error { + err := database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network)) + if err != nil { + return err + } + if servercfg.CacheEnabled() { + acls.DeleteAclFromCache(acls.ContainerID(network)) + } + return nil +} diff --git a/logic/acls/nodeacls/retrieve.go b/logic/acls/nodeacls/retrieve.go new file mode 100644 index 00000000..84895f44 --- /dev/null +++ b/logic/acls/nodeacls/retrieve.go @@ -0,0 +1,76 @@ +package nodeacls + +import ( + "encoding/json" + "fmt" + "maps" + "sync" + + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/servercfg" +) + +var NodesAllowedACLMutex = &sync.Mutex{} + +// AreNodesAllowed - checks if nodes are allowed to communicate in their network ACL +func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool { + if !servercfg.IsOldAclEnabled() { + return true + } + NodesAllowedACLMutex.Lock() + defer NodesAllowedACLMutex.Unlock() + var currentNetworkACL, err = FetchAllACLs(networkID) + if err != nil { + return false + } + var allowed bool + acls.AclMutex.Lock() + currNetworkACLNode1 := currentNetworkACL[acls.AclID(node1)] + currNetworkACLNode2 := currentNetworkACL[acls.AclID(node2)] + acls.AclMutex.Unlock() + allowed = currNetworkACLNode1.IsAllowed(acls.AclID(node2)) && currNetworkACLNode2.IsAllowed(acls.AclID(node1)) + return allowed +} + +// FetchNodeACL - fetches a specific node's ACL in a given network +func FetchNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACL, error) { + var currentNetworkACL, err = FetchAllACLs(networkID) + if err != nil { + return nil, err + } + var acl acls.ACL + acls.AclMutex.RLock() + if currentNetworkACL[acls.AclID(nodeID)] == nil { + acls.AclMutex.RUnlock() + return nil, fmt.Errorf("no node ACL present for node %s", nodeID) + } + acl = currentNetworkACL[acls.AclID(nodeID)] + acls.AclMutex.RUnlock() + return acl, nil +} + +// FetchNodeACLJson - fetches a node's acl in given network except returns the json string +func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (acls.ACLJson, error) { + currentNodeACL, err := FetchNodeACL(networkID, nodeID) + if err != nil { + return "", err + } + acls.AclMutex.RLock() + defer acls.AclMutex.RUnlock() + jsonData, err := json.Marshal(¤tNodeACL) + if err != nil { + return "", err + } + return acls.ACLJson(jsonData), nil +} + +// FetchAllACLs - fetchs all node +func FetchAllACLs(networkID NetworkID) (acls.ACLContainer, error) { + var err error + var currentNetworkACL acls.ACLContainer + currentNetworkACL, err = currentNetworkACL.Get(acls.ContainerID(networkID)) + if err != nil { + return nil, err + } + return maps.Clone(currentNetworkACL), nil +} diff --git a/logic/acls/nodeacls/types.go b/logic/acls/nodeacls/types.go new file mode 100644 index 00000000..2c3d825e --- /dev/null +++ b/logic/acls/nodeacls/types.go @@ -0,0 +1,14 @@ +package nodeacls + +import ( + "github.com/gravitl/netmaker/logic/acls" +) + +type ( + // NodeACL - interface for NodeACLs + NodeACL acls.ACL + // NodeID - node ID for ACLs + NodeID acls.AclID + // NetworkID - ACL container based on network ID for nodes + NetworkID acls.ContainerID +) diff --git a/logic/nodeacls/types.go b/logic/acls/types.go similarity index 97% rename from logic/nodeacls/types.go rename to logic/acls/types.go index 923585b8..57364508 100644 --- a/logic/nodeacls/types.go +++ b/logic/acls/types.go @@ -1,4 +1,4 @@ -package nodeacls +package acls var ( // NotPresent - 0 - not present (default) diff --git a/logic/clients.go b/logic/clients.go index e1fb70a8..dcbddcba 100644 --- a/logic/clients.go +++ b/logic/clients.go @@ -1,13 +1,9 @@ package logic import ( - "context" "errors" - "fmt" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic/nodeacls" - "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/logic/acls" + "golang.org/x/exp/slog" "sort" "github.com/gravitl/netmaker/models" @@ -30,31 +26,22 @@ var ( } SetClientDefaultACLs = func(ec *models.ExtClient) error { // allow all on CE - _networkACL := &schema.NetworkACL{ - ID: ec.Network, - } - err := _networkACL.Get(db.WithContext(context.TODO())) + networkAcls := acls.ACLContainer{} + networkAcls, err := networkAcls.Get(acls.ContainerID(ec.Network)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acls: %s", _networkACL.ID, err.Error())) + slog.Error("failed to get network acls", "error", err) return err } - - _networkACL.Access.Data()[ec.ClientID] = make(map[string]byte) - - for peerID := range _networkACL.Access.Data() { - _networkACL.Access.Data()[peerID][ec.ClientID] = nodeacls.Allowed - _networkACL.Access.Data()[ec.ClientID][peerID] = nodeacls.Allowed + networkAcls[acls.AclID(ec.ClientID)] = make(acls.ACL) + for objId := range networkAcls { + networkAcls[objId][acls.AclID(ec.ClientID)] = acls.Allowed + networkAcls[acls.AclID(ec.ClientID)][objId] = acls.Allowed } - - // delete self loop. - delete(_networkACL.Access.Data()[ec.ClientID], ec.ClientID) - - err = _networkACL.Update(db.WithContext(context.TODO())) - if err != nil { - logger.Log(0, fmt.Sprintf("failed to update network (%s) acls: %s", _networkACL.ID, err.Error())) + delete(networkAcls[acls.AclID(ec.ClientID)], acls.AclID(ec.ClientID)) + if _, err = networkAcls.Save(acls.ContainerID(ec.Network)); err != nil { + slog.Error("failed to update network acls", "error", err) return err } - return nil } SetClientACLs = func(ec *models.ExtClient, newACLs map[string]struct{}) { diff --git a/logic/extpeers.go b/logic/extpeers.go index 7e1f470c..298aa8bb 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -1,12 +1,10 @@ package logic import ( - "context" "encoding/json" "errors" "fmt" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/logic/acls" "net" "reflect" "sort" @@ -170,24 +168,19 @@ func DeleteExtClientAndCleanup(extClient models.ExtClient) error { return err } - _networkACL := &schema.NetworkACL{ - ID: extClient.Network, - } - err = _networkACL.Get(db.WithContext(context.TODO())) + //update ACLs + var networkAcls acls.ACLContainer + networkAcls, err = networkAcls.Get(acls.ContainerID(extClient.Network)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acls: %s", _networkACL.ID, err.Error())) + slog.Error("DeleteExtClientAndCleanup-update network acls: ", "Error", err.Error()) return err } - - for peerID := range _networkACL.Access.Data() { - delete(_networkACL.Access.Data()[peerID], extClient.ClientID) + for objId := range networkAcls { + delete(networkAcls[objId], acls.AclID(extClient.ClientID)) } - - delete(_networkACL.Access.Data(), extClient.ClientID) - - err = _networkACL.Update(db.WithContext(context.TODO())) - if err != nil { - logger.Log(0, fmt.Sprintf("failed to update network (%s) acls: %s", _networkACL.ID, err.Error())) + delete(networkAcls, acls.AclID(extClient.ClientID)) + if _, err = networkAcls.Save(acls.ContainerID(extClient.Network)); err != nil { + slog.Error("DeleteExtClientAndCleanup-update network acls:", "Error", err.Error()) return err } diff --git a/logic/networks.go b/logic/networks.go index 2e00791a..ae722977 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gravitl/netmaker/converters" "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/schema" "net" "strings" @@ -67,10 +68,8 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error { } } - _networkACL := &schema.NetworkACL{ - ID: network, - } - err = _networkACL.Delete(db.WithContext(context.TODO())) + // remove ACL for network + err = nodeacls.DeleteACLContainer(nodeacls.NetworkID(network)) if err != nil { logger.Log(1, "failed to remove the node acls during network delete for network,", network) } diff --git a/logic/nodeacls/node_acls.go b/logic/nodeacls/node_acls.go deleted file mode 100644 index 70809dbb..00000000 --- a/logic/nodeacls/node_acls.go +++ /dev/null @@ -1,168 +0,0 @@ -package nodeacls - -import ( - "context" - "errors" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" - "gorm.io/gorm" -) - -// CreateNodeACL - inserts or updates a node ACL on given network and adds to state -func CreateNodeACL(networkID, nodeID string, defaultVal byte) error { - if defaultVal != NotAllowed && defaultVal != Allowed { - defaultVal = NotAllowed - } - - var commit bool - dbctx := db.BeginTx(context.TODO()) - defer func() { - if commit { - db.FromContext(dbctx).Commit() - } else { - db.FromContext(dbctx).Rollback() - } - }() - - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - err := _networkACL.Get(dbctx) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - err = _networkACL.Create(dbctx) - if err != nil { - return err - } - } else { - return err - } - } - - _networkACL.Access.Data()[nodeID] = make(map[string]byte) - - for peerID := range _networkACL.Access.Data() { - _networkACL.Access.Data()[peerID][nodeID] = defaultVal - _networkACL.Access.Data()[nodeID][peerID] = defaultVal - } - - err = _networkACL.Update(dbctx) - if err != nil { - return err - } - - commit = true - return nil -} - -// AreNodesAllowed - checks if nodes are allowed to communicate in their network ACL -func AreNodesAllowed(networkID, node1, node2 string) bool { - if !servercfg.IsOldAclEnabled() { - return true - } - - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - - err := _networkACL.Get(db.WithContext(context.TODO())) - if err != nil { - return false - } - - _, ok := _networkACL.Access.Data()[node1] - if !ok { - return false - } - - _, ok = _networkACL.Access.Data()[node2] - if !ok { - return false - } - - _, ok = _networkACL.Access.Data()[node1][node2] - if !ok { - return false - } - - _, ok = _networkACL.Access.Data()[node2][node1] - if !ok { - return false - } - - node1Allows := _networkACL.Access.Data()[node1][node2] == Allowed - node2Allows := _networkACL.Access.Data()[node2][node1] == Allowed - - return node1Allows && node2Allows -} - -// ChangeAccess - changes the relationship between two nodes. -func ChangeAccess(networkID, nodeID1, nodeID2 string, value byte) error { - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - - var commit bool - dbctx := db.BeginTx(context.TODO()) - defer func() { - if commit { - db.FromContext(dbctx).Commit() - } else { - db.FromContext(dbctx).Rollback() - } - }() - - err := _networkACL.Get(dbctx) - if err != nil { - return err - } - - if _networkACL.Access.Data()[nodeID1] == nil { - _networkACL.Access.Data()[nodeID1] = make(map[string]byte) - } - - if _networkACL.Access.Data()[nodeID2] == nil { - _networkACL.Access.Data()[nodeID2] = make(map[string]byte) - } - - _networkACL.Access.Data()[nodeID1][nodeID2] = value - _networkACL.Access.Data()[nodeID2][nodeID1] = value - - err = _networkACL.Update(dbctx) - if err != nil { - return err - } - - commit = true - return nil -} - -// RemoveNodeACL - removes a specific Node's ACL. -func RemoveNodeACL(networkID, nodeID string) error { - var commit bool - dbctx := db.BeginTx(context.TODO()) - defer func() { - if commit { - db.FromContext(dbctx).Commit() - } else { - db.FromContext(dbctx).Rollback() - } - }() - - _networkACL := &schema.NetworkACL{ - ID: networkID, - } - err := _networkACL.Get(dbctx) - if err != nil { - return err - } - - delete(_networkACL.Access.Data(), nodeID) - - for peerID := range _networkACL.Access.Data() { - delete(_networkACL.Access.Data()[peerID], nodeID) - } - - return _networkACL.Update(dbctx) -} diff --git a/logic/nodes.go b/logic/nodes.go index f5fd6dcc..3179d3d7 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -7,7 +7,8 @@ import ( "fmt" "github.com/gravitl/netmaker/converters" "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logic/nodeacls" + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/schema" "net" "time" @@ -207,7 +208,7 @@ func DeleteNodeByID(node *models.Node) error { if servercfg.IsDNSMode() { SetDNS() } - err = nodeacls.RemoveNodeACL(node.Network, node.ID.String()) + _, err = nodeacls.RemoveNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String())) if err != nil { // ignoring for now, could hit a nil pointer if delete called twice logger.Log(2, "attempted to remove node ACL for node", node.ID.String()) @@ -451,14 +452,14 @@ func createNode(node *models.Node) error { SetNodeDefaults(node, true) - defaultACLVal := nodeacls.Allowed + defaultACLVal := acls.Allowed _network := &schema.Network{ ID: node.Network, } err = _network.Get(db.WithContext(context.TODO())) if err == nil { if _network.DefaultACL != "yes" { - defaultACLVal = nodeacls.NotAllowed + defaultACLVal = acls.NotAllowed } } @@ -509,7 +510,7 @@ func createNode(node *models.Node) error { return err } - err = nodeacls.CreateNodeACL(node.Network, node.ID.String(), defaultACLVal) + _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal) if err != nil { logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error()) return err diff --git a/logic/peers.go b/logic/peers.go index 6420e22e..fe709f27 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/gravitl/netmaker/logic/nodeacls" "net" "net/netip" "time" @@ -13,6 +12,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" @@ -89,7 +89,7 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) { if peer.Action != models.NODE_DELETE && !peer.PendingDelete && peer.Connected && - nodeacls.AreNodesAllowed(node.Network, node.ID.String(), peer.ID.String()) && + nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) && (allowedToComm) { networkPeersInfo[peerHost.PublicKey.String()] = models.IDandAddr{ @@ -355,7 +355,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N if peer.Action != models.NODE_DELETE && !peer.PendingDelete && peer.Connected && - nodeacls.AreNodesAllowed(node.Network, node.ID.String(), peer.ID.String()) && + nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) && (allowedToComm) && (deletedNode == nil || (peer.ID.String() != deletedNode.ID.String())) { peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection diff --git a/logic/relay.go b/logic/relay.go index ce944518..0150925d 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -4,12 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/gravitl/netmaker/logic/nodeacls" "net" "github.com/google/uuid" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) @@ -236,7 +236,7 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe continue } AddEgressInfoToPeerByAccess(relayed, &peer, eli, acls, defaultPolicy.Enabled) - if nodeacls.AreNodesAllowed(relayed.Network, relayed.ID.String(), peer.ID.String()) { + if nodeacls.AreNodesAllowed(nodeacls.NetworkID(relayed.Network), nodeacls.NodeID(relayed.ID.String()), nodeacls.NodeID(peer.ID.String())) { allowedIPs = append(allowedIPs, GetAllowedIPs(relayed, &peer, nil)...) } } diff --git a/migrate/migrate.go b/migrate/migrate.go index 7792e038..4af46ec5 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -1,12 +1,9 @@ package migrate import ( - "context" "encoding/json" "fmt" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logic/nodeacls" - "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/logic/acls" "log" "os" "time" @@ -326,12 +323,13 @@ func updateAcls() { // get current acls per network for _, network := range networks { - _networkACL := &schema.NetworkACL{ - ID: network.NetID, - } - err = _networkACL.Get(db.WithContext(context.TODO())) + var networkAcl acls.ACLContainer + networkAcl, err := networkAcl.Get(acls.ContainerID(network.NetID)) if err != nil { - logger.Log(0, fmt.Sprintf("failed to get network (%s) acl during acl migration: %s", network.NetID, err.Error())) + if database.IsEmptyRecord(err) { + continue + } + slog.Error(fmt.Sprintf("error during acls migration. error getting acls for network: %s", network.NetID), "error", err) continue } @@ -343,14 +341,14 @@ func updateAcls() { continue } - clientsMap := make(map[string]struct{}) + clientsIdMap := make(map[string]struct{}) for _, client := range clients { - clientsMap[client.ClientID] = struct{}{} + clientsIdMap[client.ClientID] = struct{}{} } - nodesMap := make(map[string]struct{}) - for nodeID := range _networkACL.Access.Data() { - nodesMap[nodeID] = struct{}{} + nodeIdsMap := make(map[string]struct{}) + for nodeId := range networkAcl { + nodeIdsMap[string(nodeId)] = struct{}{} } /* initially, networkACL has only node acls so we add client acls to it @@ -379,27 +377,24 @@ func updateAcls() { } */ for _, client := range clients { - _networkACL.Access.Data()[client.ClientID] = make(map[string]byte) + networkAcl[acls.AclID(client.ClientID)] = acls.ACL{} // add client values to node acls and create client acls with node values - for nodeID := range _networkACL.Access.Data() { + for id, nodeAcl := range networkAcl { // skip if not a node - if _, ok := nodesMap[nodeID]; !ok { + if _, ok := nodeIdsMap[string(id)]; !ok { continue } - - if _networkACL.Access.Data()[nodeID] == nil { - logger.Log(0, fmt.Sprintf("bad data: nil acl for node (%s)", nodeID)) + if nodeAcl == nil { + slog.Warn("acls migration bad data: nil node acl", "node", id, "network", network.NetID) continue } - - _networkACL.Access.Data()[nodeID][client.ClientID] = nodeacls.Allowed - _networkACL.Access.Data()[client.ClientID][nodeID] = nodeacls.Allowed - + nodeAcl[acls.AclID(client.ClientID)] = acls.Allowed + networkAcl[acls.AclID(client.ClientID)][id] = acls.Allowed if client.DeniedACLs == nil { continue - } else if _, ok := client.DeniedACLs[nodeID]; ok { - _networkACL.Access.Data()[nodeID][client.ClientID] = nodeacls.NotAllowed - _networkACL.Access.Data()[client.ClientID][nodeID] = nodeacls.NotAllowed + } else if _, ok := client.DeniedACLs[string(id)]; ok { + nodeAcl[acls.AclID(client.ClientID)] = acls.NotAllowed + networkAcl[acls.AclID(client.ClientID)][id] = acls.NotAllowed } } @@ -409,43 +404,40 @@ func updateAcls() { continue } - _networkACL.Access.Data()[client.ClientID][c.ClientID] = nodeacls.Allowed + networkAcl[acls.AclID(client.ClientID)][acls.AclID(c.ClientID)] = acls.Allowed if client.DeniedACLs == nil { continue } else if _, ok := client.DeniedACLs[c.ClientID]; ok { - _networkACL.Access.Data()[client.ClientID][c.ClientID] = nodeacls.NotAllowed + networkAcl[acls.AclID(client.ClientID)][acls.AclID(c.ClientID)] = acls.NotAllowed } } // delete oneself from its own acl - delete(_networkACL.Access.Data()[client.ClientID], client.ClientID) + delete(networkAcl[acls.AclID(client.ClientID)], acls.AclID(client.ClientID)) } // remove non-existent client and node acls - for id := range _networkACL.Access.Data() { - if _, ok := nodesMap[id]; ok { + for objId := range networkAcl { + if _, ok := nodeIdsMap[string(objId)]; ok { continue } - - if _, ok := clientsMap[id]; ok { + if _, ok := clientsIdMap[string(objId)]; ok { continue } - - // remove all occurrences of id from all acls - for peerID := range _networkACL.Access.Data() { - delete(_networkACL.Access.Data()[peerID], id) + // remove all occurances of objId from all acls + for objId2 := range networkAcl { + delete(networkAcl[objId2], objId) } - - delete(_networkACL.Access.Data(), id) + delete(networkAcl, objId) } - err = _networkACL.Update(db.WithContext(context.TODO())) - if err != nil { - logger.Log(0, fmt.Sprintf("failed to migrate acls for network (%s): %s", network.NetID, err.Error())) + // save new acls + slog.Debug(fmt.Sprintf("(migration) saving new acls for network: %s", network.NetID), "networkAcl", networkAcl) + if _, err := networkAcl.Save(acls.ContainerID(network.NetID)); err != nil { + slog.Error(fmt.Sprintf("error during acls migration. error saving new acls for network: %s", network.NetID), "error", err) continue } - - logger.Log(1, fmt.Sprintf("acls migration succeeded for network (%s)", network.NetID)) + slog.Info(fmt.Sprintf("(migration) successfully saved new acls for network: %s", network.NetID)) } } @@ -531,13 +523,8 @@ func createDefaultTagsAndPolicies() { logic.CreateDefaultTags(models.NetworkID(network.NetID)) logic.CreateDefaultNetworkPolicies(network.NetID) // delete old remote access gws policy - - _acl := &schema.NetworkACL{ - ID: fmt.Sprintf("%s.%s", network.NetID, "all-remote-access-gws"), - } - _ = _acl.Delete(db.WithContext(context.TODO())) + logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.NetID, "all-remote-access-gws")}) } - logic.MigrateAclPolicies() } diff --git a/pro/logic/ext_acls.go b/pro/logic/ext_acls.go index c6e90e20..aa2c99a4 100644 --- a/pro/logic/ext_acls.go +++ b/pro/logic/ext_acls.go @@ -4,9 +4,11 @@ import ( "context" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/nodeacls" + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" + "golang.org/x/exp/slog" ) // DenyClientNode - add a denied node to an ext client's list @@ -62,27 +64,24 @@ func SetClientDefaultACLs(ec *models.ExtClient) error { return err } - _networkACL := &schema.NetworkACL{ - ID: ec.Network, - } - err = _networkACL.Get(db.WithContext(context.TODO())) + var networkAcls acls.ACLContainer + networkAcls, err = networkAcls.Get(acls.ContainerID(ec.Network)) if err != nil { + slog.Error("failed to get network acls", "error", err) return err } - _networkACL.Access.Data()[ec.ClientID] = make(map[string]byte) + networkAcls[acls.AclID(ec.ClientID)] = make(acls.ACL) for _, _node := range _networkNodes { if _network.DefaultACL == "no" || _node.DefaultACL == "no" { DenyClientNode(ec, _node.ID) - - _networkACL.Access.Data()[ec.ClientID][_node.ID] = nodeacls.NotAllowed - _networkACL.Access.Data()[_node.ID][ec.ClientID] = nodeacls.NotAllowed + networkAcls[acls.AclID(ec.ClientID)][acls.AclID(_node.ID)] = acls.NotAllowed + networkAcls[acls.AclID(_node.ID)][acls.AclID(ec.ClientID)] = acls.NotAllowed } else { RemoveDeniedNodeFromClient(ec, _node.ID) - - _networkACL.Access.Data()[ec.ClientID][_node.ID] = nodeacls.Allowed - _networkACL.Access.Data()[_node.ID][ec.ClientID] = nodeacls.Allowed + networkAcls[acls.AclID(ec.ClientID)][acls.AclID(_node.ID)] = acls.Allowed + networkAcls[acls.AclID(_node.ID)][acls.AclID(ec.ClientID)] = acls.Allowed } } @@ -92,19 +91,17 @@ func SetClientDefaultACLs(ec *models.ExtClient) error { } for _, client := range extClients { - if _networkACL.Access.Data()[client.ClientID] == nil { - _networkACL.Access.Data()[client.ClientID] = make(map[string]byte) - } - // TODO: revisit when client-client acls are supported - _networkACL.Access.Data()[ec.ClientID][client.ClientID] = nodeacls.Allowed - _networkACL.Access.Data()[client.ClientID][ec.ClientID] = nodeacls.Allowed + networkAcls[acls.AclID(ec.ClientID)][acls.AclID(client.ClientID)] = acls.Allowed + networkAcls[acls.AclID(client.ClientID)][acls.AclID(ec.ClientID)] = acls.Allowed } - // remove access policy to self. - delete(_networkACL.Access.Data()[ec.ClientID], ec.ClientID) - - return _networkACL.Update(db.WithContext(context.TODO())) + delete(networkAcls[acls.AclID(ec.ClientID)], acls.AclID(ec.ClientID)) // remove oneself + if _, err = networkAcls.Save(acls.ContainerID(ec.Network)); err != nil { + slog.Error("failed to update network acls", "error", err) + return err + } + return nil } // SetClientACLs - overwrites an ext client's ACL @@ -126,6 +123,11 @@ func UpdateProNodeACLs(node *models.Node) error { nodeID := node.ID.String() + currentACLs, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(_network.ID)) + if err != nil { + return err + } + for _, _node := range _networkNodes { if _node.ID == nodeID { continue @@ -134,17 +136,12 @@ func UpdateProNodeACLs(node *models.Node) error { // both allow - allow // either 1 denies - deny if node.DoesACLDeny() || _node.DefaultACL == "no" { - err = nodeacls.ChangeAccess(node.Network, nodeID, _node.ID, nodeacls.NotAllowed) - if err != nil { - return err - } + currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(_node.ID), acls.NotAllowed) } else if node.DoesACLAllow() || _node.DefaultACL == "yes" { - err = nodeacls.ChangeAccess(node.Network, nodeID, _node.ID, nodeacls.Allowed) - if err != nil { - return err - } + currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(_node.ID), acls.Allowed) } } + _, err = currentACLs.Save(acls.ContainerID(node.Network)) return nil } diff --git a/schema/models.go b/schema/models.go index e42f883f..eb29cdc4 100644 --- a/schema/models.go +++ b/schema/models.go @@ -6,7 +6,6 @@ func ListModels() []interface{} { &Host{}, &Network{}, &Node{}, - &NetworkACL{}, &ACL{}, &Job{}, &Egress{}, diff --git a/schema/network_acl.go b/schema/network_acl.go deleted file mode 100644 index c4218107..00000000 --- a/schema/network_acl.go +++ /dev/null @@ -1,49 +0,0 @@ -package schema - -import ( - "context" - "github.com/gravitl/netmaker/db" - "gorm.io/datatypes" -) - -type NetworkACL struct { - ID string `gorm:"primaryKey"` - Access datatypes.JSONType[map[string]map[string]byte] -} - -func (n *NetworkACL) TableName() string { - return "network_acls_v1" -} - -func (n *NetworkACL) Create(ctx context.Context) error { - if n.Access.Data() == nil { - n.Access = datatypes.NewJSONType(map[string]map[string]byte{}) - } - - return db.FromContext(ctx).Model(&NetworkACL{}).Create(n).Error -} - -func (n *NetworkACL) Get(ctx context.Context) error { - err := db.FromContext(ctx).Model(n).First(n).Error - if err != nil { - return err - } - - if n.Access.Data() == nil { - n.Access = datatypes.NewJSONType(map[string]map[string]byte{}) - } - - return nil -} - -func (n *NetworkACL) Update(ctx context.Context) error { - if n.Access.Data() == nil { - n.Access = datatypes.NewJSONType(map[string]map[string]byte{}) - } - - return db.FromContext(ctx).Model(n).Updates(n).Error -} - -func (n *NetworkACL) Delete(ctx context.Context) error { - return db.FromContext(ctx).Model(n).Delete(n).Error -} diff --git a/serverctl/serverctl.go b/serverctl/serverctl.go index 3f228ba7..9cf2ba38 100644 --- a/serverctl/serverctl.go +++ b/serverctl/serverctl.go @@ -1,10 +1,13 @@ package serverctl import ( + "strings" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/nodeacls" + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/logic/acls/nodeacls" "golang.org/x/exp/slog" ) @@ -40,9 +43,11 @@ func setNodeDefaults() error { logic.SetNodeDefaults(&nodes[i], false) logic.UpdateNode(&nodes[i], &nodes[i]) - err = nodeacls.CreateNodeACL(nodes[i].Network, nodes[i].ID.String(), nodeacls.Allowed) - if err != nil { - logger.Log(1, "could not create a default ACL for node", nodes[i].ID.String()) + currentNodeACL, err := nodeacls.FetchNodeACL(nodeacls.NetworkID(nodes[i].Network), nodeacls.NodeID(nodes[i].ID.String())) + if (err != nil && (database.IsEmptyRecord(err) || strings.Contains(err.Error(), "no node ACL present"))) || currentNodeACL == nil { + if _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(nodes[i].Network), nodeacls.NodeID(nodes[i].ID.String()), acls.Allowed); err != nil { + logger.Log(1, "could not create a default ACL for node", nodes[i].ID.String()) + } } } return nil