diff --git a/config/config.go b/config/config.go index bf22666c..5150f1ce 100644 --- a/config/config.go +++ b/config/config.go @@ -91,6 +91,7 @@ type ServerConfig struct { Environment string `yaml:"environment"` JwtValidityDuration time.Duration `yaml:"jwt_validity_duration"` RacAutoDisable bool `yaml:"rac_auto_disable"` + CacheEnabled bool `yaml:"caching_enabled"` } // SQLConfig - Generic SQL Config diff --git a/controllers/node_test.go b/controllers/node_test.go index ba877276..f6875b85 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -10,6 +10,7 @@ import ( "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/logic/acls/nodeacls" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -217,7 +218,9 @@ func TestNodeACLs(t *testing.T) { } func deleteAllNodes() { - logic.ClearNodeCache() + if servercfg.CacheEnabled() { + logic.ClearNodeCache() + } database.DeleteAllRecords(database.NODES_TABLE_NAME) } diff --git a/logic/acls/common.go b/logic/acls/common.go index 491d574c..d43c8186 100644 --- a/logic/acls/common.go +++ b/logic/acls/common.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" ) @@ -128,8 +129,10 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err func fetchACLContainer(containerID ContainerID) (ACLContainer, error) { aclMutex.RLock() defer aclMutex.RUnlock() - if aclContainer, ok := fetchAclContainerFromCache(containerID); ok { - return aclContainer, nil + if servercfg.CacheEnabled() { + if aclContainer, ok := fetchAclContainerFromCache(containerID); ok { + return aclContainer, nil + } } aclJson, err := fetchACLContainerJson(ContainerID(containerID)) if err != nil { @@ -139,7 +142,9 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) { if err := json.Unmarshal([]byte(aclJson), ¤tNetworkACL); err != nil { return nil, err } - storeAclContainerInCache(containerID, currentNetworkACL) + if servercfg.CacheEnabled() { + storeAclContainerInCache(containerID, currentNetworkACL) + } return currentNetworkACL, nil } @@ -176,7 +181,9 @@ func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACL if err != nil { return aclContainer, err } - storeAclContainerInCache(containerID, aclContainer) + if servercfg.CacheEnabled() { + storeAclContainerInCache(containerID, aclContainer) + } return aclContainer, nil } diff --git a/logic/acls/nodeacls/modify.go b/logic/acls/nodeacls/modify.go index e803bb65..0beb1b0b 100644 --- a/logic/acls/nodeacls/modify.go +++ b/logic/acls/nodeacls/modify.go @@ -3,6 +3,7 @@ package nodeacls import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/servercfg" ) // CreateNodeACL - inserts or updates a node ACL on given network and adds to state @@ -87,6 +88,8 @@ func DeleteACLContainer(network NetworkID) error { if err != nil { return err } - acls.DeleteAclFromCache(acls.ContainerID(network)) + if servercfg.CacheEnabled() { + acls.DeleteAclFromCache(acls.ContainerID(network)) + } return nil } diff --git a/logic/extpeers.go b/logic/extpeers.go index aa6b715c..6c2a2a69 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -11,6 +11,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -80,21 +81,25 @@ func DeleteExtClient(network string, clientid string) error { if err != nil { return err } - deleteExtClientFromCache(key) + if servercfg.CacheEnabled() { + deleteExtClientFromCache(key) + } return nil } // GetNetworkExtClients - gets the ext clients of given network func GetNetworkExtClients(network string) ([]models.ExtClient, error) { var extclients []models.ExtClient - allextclients := getAllExtClientsFromCache() - if len(allextclients) != 0 { - for _, extclient := range allextclients { - if extclient.Network == network { - extclients = append(extclients, extclient) + if servercfg.CacheEnabled() { + allextclients := getAllExtClientsFromCache() + if len(allextclients) != 0 { + for _, extclient := range allextclients { + if extclient.Network == network { + extclients = append(extclients, extclient) + } } + return extclients, nil } - return extclients, nil } records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) if err != nil { @@ -111,7 +116,9 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) { } key, err := GetRecordKey(extclient.ClientID, extclient.Network) if err == nil { - storeExtClientInCache(key, extclient) + if servercfg.CacheEnabled() { + storeExtClientInCache(key, extclient) + } } if extclient.Network == network { extclients = append(extclients, extclient) @@ -127,15 +134,19 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) { if err != nil { return extclient, err } - if extclient, ok := getExtClientFromCache(key); ok { - return extclient, nil + if servercfg.CacheEnabled() { + if extclient, ok := getExtClientFromCache(key); ok { + return extclient, nil + } } data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) if err != nil { return extclient, err } err = json.Unmarshal([]byte(data), &extclient) - storeExtClientInCache(key, extclient) + if servercfg.CacheEnabled() { + storeExtClientInCache(key, extclient) + } return extclient, err } @@ -235,7 +246,9 @@ func SaveExtClient(extclient *models.ExtClient) error { if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil { return err } - storeExtClientInCache(key, *extclient) + if servercfg.CacheEnabled() { + storeExtClientInCache(key, *extclient) + } return SetNetworkNodesLastModified(extclient.Network) } diff --git a/logic/hosts.go b/logic/hosts.go index ca260de0..b9f5de26 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -81,16 +81,21 @@ const ( // GetAllHosts - returns all hosts in flat list or error func GetAllHosts() ([]models.Host, error) { - currHosts := getHostsFromCache() - if len(currHosts) != 0 { - return currHosts, nil + var currHosts []models.Host + if servercfg.CacheEnabled() { + currHosts := getHostsFromCache() + if len(currHosts) != 0 { + return currHosts, nil + } } records, err := database.FetchRecords(database.HOSTS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return nil, err } currHostsMap := make(map[string]models.Host) - defer loadHostsIntoCache(currHostsMap) + if servercfg.CacheEnabled() { + defer loadHostsIntoCache(currHostsMap) + } for k := range records { var h models.Host err = json.Unmarshal([]byte(records[k]), &h) @@ -116,16 +121,20 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost { // GetHostsMap - gets all the current hosts on machine in a map func GetHostsMap() (map[string]models.Host, error) { - hostsMap := getHostsMapFromCache() - if len(hostsMap) != 0 { - return hostsMap, nil + if servercfg.CacheEnabled() { + hostsMap := getHostsMapFromCache() + if len(hostsMap) != 0 { + return hostsMap, nil + } } records, err := database.FetchRecords(database.HOSTS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return nil, err } currHostMap := make(map[string]models.Host) - defer loadHostsIntoCache(currHostMap) + if servercfg.CacheEnabled() { + defer loadHostsIntoCache(currHostMap) + } for k := range records { var h models.Host err = json.Unmarshal([]byte(records[k]), &h) @@ -140,8 +149,10 @@ func GetHostsMap() (map[string]models.Host, error) { // GetHost - gets a host from db given id func GetHost(hostid string) (*models.Host, error) { - if host, ok := getHostFromCache(hostid); ok { - return &host, nil + if servercfg.CacheEnabled() { + if host, ok := getHostFromCache(hostid); ok { + return &host, nil + } } record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid) if err != nil { @@ -152,7 +163,10 @@ func GetHost(hostid string) (*models.Host, error) { if err = json.Unmarshal([]byte(record), &h); err != nil { return nil, err } - storeHostInCache(h) + if servercfg.CacheEnabled() { + storeHostInCache(h) + } + return &h, nil } @@ -279,7 +293,10 @@ func UpsertHost(h *models.Host) error { if err != nil { return err } - storeHostInCache(*h) + if servercfg.CacheEnabled() { + storeHostInCache(*h) + } + return nil } @@ -303,8 +320,10 @@ func RemoveHost(h *models.Host, forceDelete bool) error { if err != nil { return err } + if servercfg.CacheEnabled() { + deleteHostFromCache(h.ID.String()) + } - deleteHostFromCache(h.ID.String()) return nil } @@ -318,7 +337,9 @@ func RemoveHostByID(hostID string) error { if err != nil { return err } - deleteHostFromCache(hostID) + if servercfg.CacheEnabled() { + deleteHostFromCache(hostID) + } return nil } diff --git a/logic/nodes.go b/logic/nodes.go index a72043db..17d04a2d 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -119,7 +119,9 @@ func UpdateNodeCheckin(node *models.Node) error { if err != nil { return err } - storeNodeInCache(*node) + if servercfg.CacheEnabled() { + storeNodeInCache(*node) + } return nil } @@ -134,7 +136,9 @@ func UpsertNode(newNode *models.Node) error { if err != nil { return err } - storeNodeInCache(*newNode) + if servercfg.CacheEnabled() { + storeNodeInCache(*newNode) + } return nil } @@ -171,7 +175,9 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { if err != nil { return err } - storeNodeInCache(*newNode) + if servercfg.CacheEnabled() { + storeNodeInCache(*newNode) + } return nil } } @@ -264,7 +270,9 @@ func DeleteNodeByID(node *models.Node) error { return err } } - deleteNodeFromCache(node.ID.String()) + if servercfg.CacheEnabled() { + deleteNodeFromCache(node.ID.String()) + } if servercfg.IsDNSMode() { SetDNS() } @@ -310,12 +318,16 @@ func ValidateNode(node *models.Node, isUpdate bool) error { // GetAllNodes - returns all nodes in the DB func GetAllNodes() ([]models.Node, error) { var nodes []models.Node - nodes = getNodesFromCache() - if len(nodes) != 0 { - return nodes, nil + if servercfg.CacheEnabled() { + nodes = getNodesFromCache() + if len(nodes) != 0 { + return nodes, nil + } } nodesMap := make(map[string]models.Node) - defer loadNodesIntoCache(nodesMap) + if servercfg.CacheEnabled() { + defer loadNodesIntoCache(nodesMap) + } collection, err := database.FetchRecords(database.NODES_TABLE_NAME) if err != nil { if database.IsEmptyRecord(err) { @@ -389,8 +401,10 @@ func GetRecordKey(id string, network string) (string, error) { } func GetNodeByID(uuid string) (models.Node, error) { - if node, ok := getNodeFromCache(uuid); ok { - return node, nil + if servercfg.CacheEnabled() { + if node, ok := getNodeFromCache(uuid); ok { + return node, nil + } } var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid) if err != nil { @@ -400,7 +414,9 @@ func GetNodeByID(uuid string) (models.Node, error) { if err = json.Unmarshal([]byte(record), &node); err != nil { return models.Node{}, err } - storeNodeInCache(node) + if servercfg.CacheEnabled() { + storeNodeInCache(node) + } return node, nil } @@ -556,7 +572,9 @@ func createNode(node *models.Node) error { if err != nil { return err } - storeNodeInCache(*node) + if servercfg.CacheEnabled() { + storeNodeInCache(*node) + } _, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal) if err != nil { logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error()) diff --git a/scripts/netmaker.default.env b/scripts/netmaker.default.env index 247791d4..ade0e5c8 100644 --- a/scripts/netmaker.default.env +++ b/scripts/netmaker.default.env @@ -81,3 +81,5 @@ OIDC_ISSUER= JWT_VALIDITY_DURATION=43200 # Auto disable a user's connecteds clients bassed on JWT token expiration RAC_AUTO_DISABLE="true" +# if turned on data will be cached on to improve performance significantly (IMPORTANT: If HA set to `false` ) +CACHING_ENABLED="true diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index d032e6bf..73d23ba1 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -207,6 +207,17 @@ func GetDB() string { return database } +// CacheEnabled - checks if cache is enabled +func CacheEnabled() bool { + caching := false + if os.Getenv("CACHING_ENABLED") != "" { + caching = os.Getenv("CACHING_ENABLED") == "true" + } else if config.Config.Server.Database != "" { + caching = config.Config.Server.CacheEnabled + } + return caching +} + // GetAPIHost - gets the api host func GetAPIHost() string { serverhost := "127.0.0.1"