diff --git a/controllers/dns_test.go b/controllers/dns_test.go index 8a9b886f..e070cfed 100644 --- a/controllers/dns_test.go +++ b/controllers/dns_test.go @@ -51,8 +51,7 @@ func TestGetNodeDNS(t *testing.T) { createNet() createHost() t.Run("NoNodes", func(t *testing.T) { - dns, err := logic.GetNodeDNS("skynet") - assert.EqualError(t, err, "could not find any records") + dns, _ := logic.GetNodeDNS("skynet") assert.Equal(t, []models.DNSEntry(nil), dns) }) t.Run("NodeExists", func(t *testing.T) { diff --git a/controllers/ext_client.go b/controllers/ext_client.go index d324cc42..7552dcdb 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -10,7 +10,6 @@ import ( "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/pro" @@ -102,7 +101,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) { clients := []models.ExtClient{} var err error if len(networksSlice) > 0 && networksSlice[0] == logic.ALL_NETWORK_ACCESS { - clients, err = functions.GetAllExtClients() + clients, err = logic.GetAllExtClients() if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, "failed to get all extclients: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) diff --git a/controllers/hosts.go b/controllers/hosts.go index da329c84..f9f6abfe 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -49,38 +49,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - //isMasterAdmin := r.Header.Get("ismaster") == "yes" - //user, err := logic.GetUser(r.Header.Get("user")) - //if err != nil && !isMasterAdmin { - // logger.Log(0, r.Header.Get("user"), "failed to fetch user: ", err.Error()) - // logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - // return - //} - // return JSON/API formatted hosts - //ret := []models.ApiHost{} apiHosts := logic.GetAllHostsAPI(currentHosts[:]) logger.Log(2, r.Header.Get("user"), "fetched all hosts") - //for _, host := range apiHosts { - // nodes := host.Nodes - // // work on the copy - // host.Nodes = []string{} - // for _, nid := range nodes { - // node, err := logic.GetNodeByID(nid) - // if err != nil { - // logger.Log(0, r.Header.Get("user"), "failed to fetch node: ", err.Error()) - // // TODO find the reason for the DB error, skip this node for now - // continue - // } - // if !isMasterAdmin && !logic.UserHasNetworksAccess([]string{node.Network}, user) { - // continue - // } - // host.Nodes = append(host.Nodes, nid) - // } - // // add to the response only if has perms to some nodes / networks - // if len(host.Nodes) > 0 { - // ret = append(ret, host) - // } - //} logic.SortApiHosts(apiHosts[:]) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(apiHosts) diff --git a/controllers/node_test.go b/controllers/node_test.go index d4519f44..ba877276 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -217,6 +217,7 @@ func TestNodeACLs(t *testing.T) { } func deleteAllNodes() { + logic.ClearNodeCache() database.DeleteAllRecords(database.NODES_TABLE_NAME) } diff --git a/logic/acls/common.go b/logic/acls/common.go index 79d1bd67..8c28cb6a 100644 --- a/logic/acls/common.go +++ b/logic/acls/common.go @@ -2,10 +2,37 @@ package acls import ( "encoding/json" + "sync" "github.com/gravitl/netmaker/database" + "golang.org/x/exp/slog" ) +var ( + aclCacheMutex = &sync.RWMutex{} + aclCacheMap = make(map[ContainerID]ACLContainer) + aclMutex = &sync.RWMutex{} +) + +func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) { + aclCacheMutex.RLock() + aclCont, ok = aclCacheMap[containerID] + aclCacheMutex.RUnlock() + return +} + +func storeAclContainerInCache(containerID ContainerID, aclContainer ACLContainer) { + aclCacheMutex.Lock() + aclCacheMap[containerID] = aclContainer + aclCacheMutex.Unlock() +} + +func DeleteAclFromCache(containerID ContainerID) { + aclCacheMutex.Lock() + delete(aclCacheMap, containerID) + aclCacheMutex.Unlock() +} + // == type functions == // ACL.Allow - allows access by ID in memory @@ -52,6 +79,22 @@ func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer { // ACLContainer.ChangeAccess - changes the relationship between two nodes in memory func (networkACL ACLContainer) ChangeAccess(ID1, ID2 AclID, value byte) { + if _, ok := networkACL[ID1]; !ok { + slog.Error("ACL missing for ", "id", ID1) + return + } + if _, ok := networkACL[ID2]; !ok { + slog.Error("ACL missing for ", "id", ID2) + return + } + if _, ok := networkACL[ID1][ID2]; !ok { + slog.Error("ACL missing for ", "id1", ID1, "id2", ID2) + return + } + if _, ok := networkACL[ID2][ID1]; !ok { + slog.Error("ACL missing for ", "id2", ID2, "id1", ID1) + return + } networkACL[ID1][ID2] = value networkACL[ID2][ID1] = value } @@ -75,6 +118,11 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err // fetchACLContainer - fetches all current rules in given ACL container func fetchACLContainer(containerID ContainerID) (ACLContainer, error) { + aclMutex.RLock() + defer aclMutex.RUnlock() + if aclContainer, ok := fetchAclContainerFromCache(containerID); ok { + return aclContainer, nil + } aclJson, err := fetchACLContainerJson(ContainerID(containerID)) if err != nil { return nil, err @@ -83,6 +131,7 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) { if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { return nil, err } + storeAclContainerInCache(containerID, currentNetworkACL) return currentNetworkACL, nil } @@ -109,10 +158,18 @@ func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) { // upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the container ID // if nil, create it func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACLContainer, error) { + aclMutex.Lock() + defer aclMutex.Unlock() if aclContainer == nil { aclContainer = make(ACLContainer) } - return aclContainer, database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME) + + err := database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME) + if err != nil { + return aclContainer, err + } + storeAclContainerInCache(containerID, aclContainer) + return aclContainer, nil } func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson { diff --git a/logic/acls/nodeacls/modify.go b/logic/acls/nodeacls/modify.go index 42604341..e803bb65 100644 --- a/logic/acls/nodeacls/modify.go +++ b/logic/acls/nodeacls/modify.go @@ -83,5 +83,10 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error // DeleteACLContainer - removes an ACLContainer state from db func DeleteACLContainer(network NetworkID) error { - return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network)) + err := database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network)) + if err != nil { + return err + } + acls.DeleteAclFromCache(acls.ContainerID(network)) + return nil } diff --git a/logic/dns.go b/logic/dns.go index 146a72fe..0991c8ce 100644 --- a/logic/dns.go +++ b/logic/dns.go @@ -69,16 +69,12 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) { var dns []models.DNSEntry - collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + nodes, err := GetNetworkNodes(network) if err != nil { return dns, err } - for _, value := range collection { - var node models.Node - if err = json.Unmarshal([]byte(value), &node); err != nil { - continue - } + for _, node := range nodes { if node.Network != network { continue } diff --git a/logic/extpeers.go b/logic/extpeers.go index 928e1cde..643a8729 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -3,58 +3,56 @@ package logic import ( "encoding/json" "fmt" + "sync" "time" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// GetExtPeersList - gets the ext peers lists -func GetExtPeersList(node *models.Node) ([]models.ExtPeersResponse, error) { +var ( + extClientCacheMutex = &sync.RWMutex{} + extClientCacheMap = make(map[string]models.ExtClient) +) - var peers []models.ExtPeersResponse - records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) - - if err != nil { - return peers, err +func getAllExtClientsFromCache() (extClients []models.ExtClient) { + extClientCacheMutex.RLock() + for _, extclient := range extClientCacheMap { + extClients = append(extClients, extclient) } + extClientCacheMutex.RUnlock() + return +} - for _, value := range records { - var peer models.ExtPeersResponse - var extClient models.ExtClient - err = json.Unmarshal([]byte(value), &peer) - if err != nil { - logger.Log(2, "failed to unmarshal peer when getting ext peer list") - continue - } - err = json.Unmarshal([]byte(value), &extClient) - if err != nil { - logger.Log(2, "failed to unmarshal ext client") - continue - } +func deleteExtClientFromCache(key string) { + extClientCacheMutex.Lock() + delete(extClientCacheMap, key) + extClientCacheMutex.Unlock() +} - if extClient.Enabled && extClient.Network == node.Network && extClient.IngressGatewayID == node.ID.String() { - peers = append(peers, peer) - } - } - return peers, err +func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) { + extClientCacheMutex.RLock() + extclient, ok = extClientCacheMap[key] + extClientCacheMutex.RUnlock() + return +} + +func storeExtClientInCache(key string, extclient models.ExtClient) { + extClientCacheMutex.Lock() + extClientCacheMap[key] = extclient + extClientCacheMutex.Unlock() } // ExtClient.GetEgressRangesOnNetwork - returns the egress ranges on network of ext client func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) { var result []string - nodesData, err := database.FetchRecords(database.NODES_TABLE_NAME) + networkNodes, err := GetNetworkNodes(client.Network) if err != nil { return []string{}, err } - for _, nodeData := range nodesData { - var currentNode models.Node - if err = json.Unmarshal([]byte(nodeData), ¤tNode); err != nil { - continue - } + for _, currentNode := range networkNodes { if currentNode.Network != client.Network { continue } @@ -75,13 +73,25 @@ func DeleteExtClient(network string, clientid string) error { return err } err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key) - return err + if err != nil { + return err + } + deleteExtClientFromCache(key) + return nil } // GetNetworkExtClients - gets the ext clients of given network func GetNetworkExtClients(network string) ([]models.ExtClient, error) { var extclients []models.ExtClient - + allextclients := getAllExtClientsFromCache() + if len(allextclients) != 0 { + for _, extclient := range allextclients { + if extclient.Network == network { + extclients = append(extclients, extclient) + } + } + return extclients, nil + } records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) if err != nil { return extclients, err @@ -92,6 +102,10 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) { if err != nil { continue } + key, err := GetRecordKey(extclient.ClientID, network) + if err == nil { + storeExtClientInCache(key, extclient) + } if extclient.Network == network { extclients = append(extclients, extclient) } @@ -106,12 +120,15 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) { if err != nil { return extclient, err } + if extclient, ok := getExtClientFromCache(key); ok { + return extclient, nil + } data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) if err != nil { return extclient, err } err = json.Unmarshal([]byte(data), &extclient) - + storeExtClientInCache(key, extclient) return extclient, err } @@ -190,6 +207,7 @@ func SaveExtClient(extclient *models.ExtClient) error { if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil { return err } + storeExtClientInCache(key, *extclient) return SetNetworkNodesLastModified(extclient.Network) } diff --git a/logic/gateway.go b/logic/gateway.go index 556b108b..083092ee 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -1,7 +1,6 @@ package logic import ( - "encoding/json" "errors" "fmt" "time" @@ -53,11 +52,7 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro node.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled) node.EgressGatewayRequest = gateway // store entire request for use when preserving the egress gateway node.SetLastModified() - nodeData, err := json.Marshal(&node) - if err != nil { - return node, err - } - if err = database.Insert(node.ID.String(), string(nodeData), database.NODES_TABLE_NAME); err != nil { + if err = UpsertNode(&node); err != nil { return models.Node{}, err } return node, nil @@ -84,12 +79,7 @@ func DeleteEgressGateway(network, nodeid string) (models.Node, error) { node.EgressGatewayRanges = []string{} node.EgressGatewayRequest = models.EgressGatewayRequest{} // remove preserved request as the egress gateway is gone node.SetLastModified() - - data, err := json.Marshal(&node) - if err != nil { - return models.Node{}, err - } - if err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil { + if err = UpsertNode(&node); err != nil { return models.Node{}, err } return node, nil @@ -128,11 +118,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq if ingress.Failover && servercfg.Is_EE { node.Failover = true } - data, err := json.Marshal(&node) - if err != nil { - return models.Node{}, err - } - err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) + err = UpsertNode(&node) if err != nil { return models.Node{}, err } @@ -173,12 +159,7 @@ func DeleteIngressGateway(networkName string, nodeid string) (models.Node, bool, node.EgressGatewayRequest.NodeID, node.EgressGatewayRequest.NetID, err)) } } - - data, err := json.Marshal(&node) - if err != nil { - return models.Node{}, false, removedClients, err - } - err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) + err = UpsertNode(&node) if err != nil { return models.Node{}, wasFailover, removedClients, err } diff --git a/logic/hosts.go b/logic/hosts.go index d2ed2f30..06b46fd3 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -10,6 +10,7 @@ import ( "net/http" "sort" "strconv" + "sync" "github.com/devilcove/httpclient" "github.com/google/uuid" @@ -20,6 +21,11 @@ import ( "golang.org/x/crypto/bcrypt" ) +var ( + hostCacheMutex = &sync.RWMutex{} + hostsCacheMap = make(map[string]models.Host) +) + var ( // ErrHostExists error indicating that host exists when trying to create new host ErrHostExists error = errors.New("host already exists") @@ -27,6 +33,46 @@ var ( ErrInvalidHostID error = errors.New("invalid host id") ) +func getHostsFromCache() (hosts []models.Host) { + hostCacheMutex.RLock() + for _, host := range hostsCacheMap { + hosts = append(hosts, host) + } + hostCacheMutex.RUnlock() + return +} + +func getHostsMapFromCache() (hostsMap map[string]models.Host) { + hostCacheMutex.RLock() + hostsMap = hostsCacheMap + hostCacheMutex.RUnlock() + return +} + +func getHostFromCache(hostID string) (host models.Host, ok bool) { + hostCacheMutex.RLock() + host, ok = hostsCacheMap[hostID] + hostCacheMutex.RUnlock() + return +} + +func storeHostInCache(h models.Host) { + hostCacheMutex.Lock() + hostsCacheMap[h.ID.String()] = h + hostCacheMutex.Unlock() +} + +func deleteHostFromCache(hostID string) { + hostCacheMutex.Lock() + delete(hostsCacheMap, hostID) + hostCacheMutex.Unlock() +} +func loadHostsIntoCache(hMap map[string]models.Host) { + hostCacheMutex.Lock() + hostsCacheMap = hMap + hostCacheMutex.Unlock() +} + const ( maxPort = 1<<16 - 1 minPort = 1025 @@ -34,17 +80,28 @@ const ( // GetAllHosts - returns all hosts in flat list or error func GetAllHosts() ([]models.Host, error) { - currHostMap, err := GetHostsMap() - if err != nil { + + currHosts := getHostsFromCache() + if len(currHosts) != 0 { + return currHosts, nil + } + records, err := database.FetchRecords(database.HOSTS_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { return nil, err } - var currentHosts = []models.Host{} - for k := range currHostMap { - var h = *currHostMap[k] - currentHosts = append(currentHosts, h) + currHostsMap := make(map[string]models.Host) + defer loadHostsIntoCache(currHostsMap) + for k := range records { + var h models.Host + err = json.Unmarshal([]byte(records[k]), &h) + if err != nil { + return nil, err + } + currHosts = append(currHosts, h) + currHostsMap[h.ID.String()] = h } - return currentHosts, nil + return currHosts, nil } // GetAllHostsAPI - get's all the hosts in an API usable format @@ -58,19 +115,24 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost { } // GetHostsMap - gets all the current hosts on machine in a map -func GetHostsMap() (map[string]*models.Host, error) { +func GetHostsMap() (map[string]models.Host, error) { + hostsMap := getHostsMapFromCache() + if len(hostsMap) != 0 { + return hostsMap, nil + } records, err := database.FetchRecords(database.HOSTS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return nil, err } - currHostMap := make(map[string]*models.Host) + currHostMap := make(map[string]models.Host) + defer loadHostsIntoCache(currHostMap) for k := range records { var h models.Host err = json.Unmarshal([]byte(records[k]), &h) if err != nil { return nil, err } - currHostMap[h.ID.String()] = &h + currHostMap[h.ID.String()] = h } return currHostMap, nil @@ -78,6 +140,10 @@ func GetHostsMap() (map[string]*models.Host, error) { // GetHost - gets a host from db given id func GetHost(hostid string) (*models.Host, error) { + + if host, ok := getHostFromCache(hostid); ok { + return &host, nil + } record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid) if err != nil { return nil, err @@ -87,7 +153,7 @@ func GetHost(hostid string) (*models.Host, error) { if err = json.Unmarshal([]byte(record), &h); err != nil { return nil, err } - + storeHostInCache(h) return &h, nil } @@ -221,8 +287,12 @@ func UpsertHost(h *models.Host) error { if err != nil { return err } - - return database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME) + err = database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME) + if err != nil { + return err + } + storeHostInCache(*h) + return nil } // RemoveHost - removes a given host from server @@ -233,8 +303,12 @@ func RemoveHost(h *models.Host) error { if servercfg.IsUsingTurn() { DeRegisterHostWithTurn(h.ID.String()) } - - return database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String()) + err := database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String()) + if err != nil { + return err + } + deleteHostFromCache(h.ID.String()) + return nil } // RemoveHostByID - removes a given host by id from server @@ -242,7 +316,13 @@ func RemoveHostByID(hostID string) error { if servercfg.IsUsingTurn() { DeRegisterHostWithTurn(hostID) } - return database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID) + + err := database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID) + if err != nil { + return err + } + deleteHostFromCache(hostID) + return nil } // UpdateHostNetwork - adds/deletes host from a network diff --git a/logic/networks.go b/logic/networks.go index ecdce22a..102c39d3 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -115,24 +115,8 @@ func CreateNetwork(network models.Network) (models.Network, error) { // GetNetworkNonServerNodeCount - get number of network non server nodes func GetNetworkNonServerNodeCount(networkName string) (int, error) { - - collection, err := database.FetchRecords(database.NODES_TABLE_NAME) - count := 0 - if err != nil && !database.IsEmptyRecord(err) { - return count, err - } - for _, value := range collection { - var node models.Node - if err = json.Unmarshal([]byte(value), &node); err != nil { - return count, err - } else { - if node.Network == networkName { - count++ - } - } - } - - return count, nil + nodes, err := GetNetworkNodes(networkName) + return len(nodes), err } // GetParentNetwork - get parent network @@ -210,18 +194,12 @@ func UniqueAddress(networkName string, reverse bool) (net.IP, error) { func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool { isunique := true - collection, err := database.FetchRecords(tableName) - if err != nil { - return isunique - } - - for _, value := range collection { // filter - - if tableName == database.NODES_TABLE_NAME { - var node models.Node - if err = json.Unmarshal([]byte(value), &node); err != nil { - continue - } + if tableName == database.NODES_TABLE_NAME { + nodes, err := GetNetworkNodes(network) + if err != nil { + return isunique + } + for _, node := range nodes { if isIpv6 { if node.Address6.IP.String() == ip && node.Network == network { return false @@ -231,11 +209,15 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool { return false } } - } else if tableName == database.EXT_CLIENT_TABLE_NAME { - var extClient models.ExtClient - if err = json.Unmarshal([]byte(value), &extClient); err != nil { - continue - } + } + + } else if tableName == database.EXT_CLIENT_TABLE_NAME { + + extClients, err := GetNetworkExtClients(network) + if err != nil { + return isunique + } + for _, extClient := range extClients { // filter if isIpv6 { if (extClient.Address6 == ip) && extClient.Network == network { return false @@ -247,7 +229,6 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool { } } } - } return isunique @@ -298,149 +279,6 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) { return add, errors.New("ERROR: No unique IPv6 addresses available. Check network subnet") } -// UpdateNetworkLocalAddresses - updates network localaddresses -func UpdateNetworkLocalAddresses(networkName string) error { - - collection, err := database.FetchRecords(database.NODES_TABLE_NAME) - - if err != nil { - return err - } - - for _, value := range collection { - - var node models.Node - - err := json.Unmarshal([]byte(value), &node) - if err != nil { - fmt.Println("error in node address assignment!") - return err - } - if node.Network == networkName { - var ipaddr net.IP - var iperr error - ipaddr, iperr = UniqueAddress(networkName, false) - if iperr != nil { - fmt.Println("error in node address assignment!") - return iperr - } - - node.Address.IP = ipaddr - newNodeData, err := json.Marshal(&node) - if err != nil { - logger.Log(1, "error in node address assignment!") - return err - } - database.Insert(node.ID.String(), string(newNodeData), database.NODES_TABLE_NAME) - } - } - - return nil -} - -// RemoveNetworkNodeIPv6Addresses - removes network node IPv6 addresses -func RemoveNetworkNodeIPv6Addresses(networkName string) error { - - collections, err := database.FetchRecords(database.NODES_TABLE_NAME) - if err != nil { - return err - } - - for _, value := range collections { - - var node models.Node - err := json.Unmarshal([]byte(value), &node) - if err != nil { - fmt.Println("error in node address assignment!") - return err - } - if node.Network == networkName { - node.Address6.IP = nil - data, err := json.Marshal(&node) - if err != nil { - return err - } - database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) - } - } - - return nil -} - -// UpdateNetworkNodeAddresses - updates network node addresses -func UpdateNetworkNodeAddresses(networkName string) error { - - collections, err := database.FetchRecords(database.NODES_TABLE_NAME) - if err != nil { - return err - } - - for _, value := range collections { - - var node models.Node - err := json.Unmarshal([]byte(value), &node) - if err != nil { - logger.Log(1, "error in node ipv4 address assignment!") - return err - } - if node.Network == networkName { - var ipaddr net.IP - var iperr error - ipaddr, iperr = UniqueAddress(networkName, false) - if iperr != nil { - logger.Log(1, "error in node ipv4 address assignment!") - return iperr - } - - node.Address.IP = ipaddr - data, err := json.Marshal(&node) - if err != nil { - return err - } - database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) - } - } - - return nil -} - -// UpdateNetworkNodeAddresses6 - updates network node addresses -func UpdateNetworkNodeAddresses6(networkName string) error { - - collections, err := database.FetchRecords(database.NODES_TABLE_NAME) - if err != nil { - return err - } - - for _, value := range collections { - - var node models.Node - err := json.Unmarshal([]byte(value), &node) - if err != nil { - logger.Log(1, "error in node ipv6 address assignment!") - return err - } - if node.Network == networkName { - var ipaddr net.IP - var iperr error - ipaddr, iperr = UniqueAddress6(networkName, false) - if iperr != nil { - logger.Log(1, "error in node ipv6 address assignment!") - return iperr - } - - node.Address6.IP = ipaddr - data, err := json.Marshal(&node) - if err != nil { - return err - } - database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) - } - } - - return nil -} - // IsNetworkNameUnique - checks to see if any other networks have the same name (id) func IsNetworkNameUnique(network *models.Network) (bool, error) { diff --git a/logic/nodes.go b/logic/nodes.go index 1becdb64..c0d79f35 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sort" + "sync" "time" validator "github.com/go-playground/validator/v10" @@ -17,11 +18,53 @@ import ( "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/logic/pro/proacls" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/validation" ) +var ( + nodeCacheMutex = &sync.RWMutex{} + nodesCacheMap = make(map[string]models.Node) +) + +func getNodeFromCache(nodeID string) (node models.Node, ok bool) { + nodeCacheMutex.RLock() + node, ok = nodesCacheMap[nodeID] + nodeCacheMutex.RUnlock() + return +} +func getNodesFromCache() (nodes []models.Node) { + nodeCacheMutex.RLock() + for _, node := range nodesCacheMap { + nodes = append(nodes, node) + } + nodeCacheMutex.RUnlock() + return +} + +func deleteNodeFromCache(nodeID string) { + nodeCacheMutex.Lock() + delete(nodesCacheMap, nodeID) + nodeCacheMutex.Unlock() +} + +func storeNodeInCache(node models.Node) { + nodeCacheMutex.Lock() + nodesCacheMap[node.ID.String()] = node + nodeCacheMutex.Unlock() +} + +func loadNodesIntoCache(nMap map[string]models.Node) { + nodeCacheMutex.Lock() + nodesCacheMap = nMap + nodeCacheMutex.Unlock() +} +func ClearNodeCache() { + nodeCacheMutex.Lock() + nodesCacheMap = make(map[string]models.Node) + nodeCacheMutex.Unlock() +} + const ( // RELAY_NODE_ERR - error to return if relay node is unfound RELAY_NODE_ERR = "could not find relay for node" @@ -72,7 +115,12 @@ func UpdateNodeCheckin(node *models.Node) error { if err != nil { return err } - return database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) + err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) + if err != nil { + return err + } + storeNodeInCache(*node) + return nil } // UpsertNode - updates node in the DB @@ -82,7 +130,12 @@ func UpsertNode(newNode *models.Node) error { if err != nil { return err } - return database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) + err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) + if err != nil { + return err + } + storeNodeInCache(*newNode) + return nil } // UpdateNode - takes a node and updates another node with it's values @@ -114,7 +167,12 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { if data, err := json.Marshal(newNode); err != nil { return err } else { - return database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) + err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) + if err != nil { + return err + } + storeNodeInCache(*newNode) + return nil } } @@ -172,6 +230,7 @@ func deleteNodeByID(node *models.Node) error { return err } } + deleteNodeFromCache(node.ID.String()) if servercfg.IsDNSMode() { SetDNS() } @@ -237,7 +296,12 @@ func IsFailoverPresent(network string) bool { // GetAllNodes - returns all nodes in the DB func GetAllNodes() ([]models.Node, error) { var nodes []models.Node - + nodes = getNodesFromCache() + if len(nodes) != 0 { + return nodes, nil + } + nodesMap := make(map[string]models.Node) + defer loadNodesIntoCache(nodesMap) collection, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { if database.IsEmptyRecord(err) { @@ -255,6 +319,7 @@ func GetAllNodes() ([]models.Node, error) { } // add node to our array nodes = append(nodes, node) + nodesMap[node.ID.String()] = node } return nodes, nil @@ -309,46 +374,10 @@ func GetRecordKey(id string, network string) (string, error) { return id + "###" + network, nil } -// GetNodesByAddress - gets a node by mac address -func GetNodesByAddress(network string, addresses []string) ([]models.Node, error) { - var nodes []models.Node - allnodes, err := GetAllNodes() - if err != nil { - return []models.Node{}, err - } - for _, node := range allnodes { - if node.Network == network && ncutils.StringSliceContains(addresses, node.Address.String()) { - nodes = append(nodes, node) - } - } - return nodes, nil -} - -// GetDeletedNodeByMacAddress - get a deleted node -func GetDeletedNodeByMacAddress(network string, macaddress string) (models.Node, error) { - - var node models.Node - - key, err := GetRecordKey(macaddress, network) - if err != nil { - return node, err - } - - record, err := database.FetchRecord(database.DELETED_NODES_TABLE_NAME, key) - if err != nil { - return models.Node{}, err - } - - if err = json.Unmarshal([]byte(record), &node); err != nil { - return models.Node{}, err - } - - SetNodeDefaults(&node) - - return node, nil -} - func GetNodeByID(uuid string) (models.Node, error) { + if node, ok := getNodeFromCache(uuid); ok { + return node, nil + } var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid) if err != nil { return models.Node{}, err @@ -357,6 +386,7 @@ func GetNodeByID(uuid string) (models.Node, error) { if err = json.Unmarshal([]byte(record), &node); err != nil { return models.Node{}, err } + storeNodeInCache(node) return node, nil } @@ -506,7 +536,7 @@ func createNode(node *models.Node) error { if err != nil { return err } - + storeNodeInCache(*node) _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal) if err != nil { logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error()) diff --git a/logic/relay.go b/logic/relay.go index 03b7eb33..4782c156 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -1,12 +1,10 @@ package logic import ( - "encoding/json" "errors" "fmt" "net" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" ) @@ -33,25 +31,11 @@ func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error) node.IsRelay = true node.RelayedNodes = relay.RelayedNodes node.SetLastModified() - nodeData, err := json.Marshal(&node) + err = UpsertNode(&node) if err != nil { return returnnodes, node, err } - if err = database.Insert(node.ID.String(), string(nodeData), database.NODES_TABLE_NAME); err != nil { - return returnnodes, models.Node{}, err - } returnnodes = SetRelayedNodes(true, relay.NodeID, relay.RelayedNodes) - for _, relayedNode := range returnnodes { - data, err := json.Marshal(&relayedNode) - if err != nil { - logger.Log(0, "marshalling relayed node", err.Error()) - continue - } - if err := database.Insert(relayedNode.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil { - logger.Log(0, "inserting relayed node", err.Error()) - continue - } - } return returnnodes, node, nil } @@ -71,12 +55,7 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N node.RelayedBy = "" } node.SetLastModified() - data, err := json.Marshal(&node) - if err != nil { - logger.Log(0, "setRelayedNodes.Marshal", err.Error()) - continue - } - if err := database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil { + if err := UpsertNode(&node); err != nil { logger.Log(0, "setRelayedNodes.Insert", err.Error()) continue } @@ -145,11 +124,7 @@ func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) { node.IsRelay = false node.RelayedNodes = []string{} node.SetLastModified() - data, err := json.Marshal(&node) - if err != nil { - return returnnodes, models.Node{}, err - } - if err = database.Insert(nodeid, string(data), database.NODES_TABLE_NAME); err != nil { + if err = UpsertNode(&node); err != nil { return returnnodes, models.Node{}, err } return returnnodes, node, nil diff --git a/mq/mq.go b/mq/mq.go index 5e6fdb72..9be06871 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -80,7 +80,7 @@ func SetupMQTT() { logger.Log(0, "node metrics subscription failed") } - opts.SetOrderMatters(true) + opts.SetOrderMatters(false) opts.SetResumeSubs(true) }) mqclient = mqtt.NewClient(opts)