Merge pull request #3340 from gravitl/NET-1964

NET-1964: add node mutex to model
This commit is contained in:
Abhishek K 2025-02-27 12:39:58 +04:00 committed by GitHub
commit 2f0d289813
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 120 additions and 36 deletions

View file

@ -17,7 +17,6 @@ import (
var ( var (
aclCacheMutex = &sync.RWMutex{} aclCacheMutex = &sync.RWMutex{}
aclCacheMap = make(map[string]models.Acl) aclCacheMap = make(map[string]models.Acl)
aclTagsMutex = &sync.RWMutex{}
) )
func MigrateAclPolicies() { func MigrateAclPolicies() {
@ -576,10 +575,22 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
if peer.IsStatic { if peer.IsStatic {
peer = peer.StaticNode.ConvertToStaticNode() peer = peer.StaticNode.ConvertToStaticNode()
} }
aclTagsMutex.RLock() var nodeTags, peerTags map[models.TagID]struct{}
peerTags := maps.Clone(peer.Tags) if node.Mutex != nil {
nodeTags := maps.Clone(node.Tags) node.Mutex.Lock()
aclTagsMutex.RUnlock() nodeTags = maps.Clone(node.Tags)
node.Mutex.Unlock()
} else {
nodeTags = node.Tags
}
if peer.Mutex != nil {
peer.Mutex.Lock()
peerTags = maps.Clone(peer.Tags)
peer.Mutex.Unlock()
} else {
peerTags = peer.Tags
}
if checkDefaultPolicy { if checkDefaultPolicy {
// check default policy if all allowed return true // check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy) defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@ -661,10 +672,21 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
if peer.IsStatic { if peer.IsStatic {
peer = peer.StaticNode.ConvertToStaticNode() peer = peer.StaticNode.ConvertToStaticNode()
} }
aclTagsMutex.RLock() var nodeTags, peerTags map[models.TagID]struct{}
peerTags := maps.Clone(peer.Tags) if node.Mutex != nil {
nodeTags := maps.Clone(node.Tags) node.Mutex.Lock()
aclTagsMutex.RUnlock() nodeTags = maps.Clone(node.Tags)
node.Mutex.Unlock()
} else {
nodeTags = node.Tags
}
if peer.Mutex != nil {
peer.Mutex.Lock()
peerTags = maps.Clone(peer.Tags)
peer.Mutex.Unlock()
} else {
peerTags = peer.Tags
}
if checkDefaultPolicy { if checkDefaultPolicy {
// check default policy if all allowed return true // check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy) defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@ -862,7 +884,15 @@ func getUserAclRulesForNode(targetnode *models.Node,
userGrpMap := GetUserGrpMap() userGrpMap := GetUserGrpMap()
allowedUsers := make(map[string][]models.Acl) allowedUsers := make(map[string][]models.Acl)
acls := listUserPolicies(models.NetworkID(targetnode.Network)) acls := listUserPolicies(models.NetworkID(targetnode.Network))
for nodeTag := range targetnode.Tags { var targetNodeTags = make(map[models.TagID]struct{})
if targetnode.Mutex != nil {
targetnode.Mutex.Lock()
targetNodeTags = maps.Clone(targetnode.Tags)
targetnode.Mutex.Unlock()
} else {
targetNodeTags = maps.Clone(targetnode.Tags)
}
for nodeTag := range targetNodeTags {
for _, acl := range acls { for _, acl := range acls {
if !acl.Enabled { if !acl.Enabled {
continue continue
@ -886,6 +916,7 @@ func getUserAclRulesForNode(targetnode *models.Node,
} }
} }
} }
for _, userNode := range userNodes { for _, userNode := range userNodes {
if !userNode.StaticNode.Enabled { if !userNode.StaticNode.Enabled {
continue continue
@ -942,8 +973,17 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
} }
acls := listDevicePolicies(models.NetworkID(targetnode.Network)) acls := listDevicePolicies(models.NetworkID(targetnode.Network))
targetnode.Tags["*"] = struct{}{}
for nodeTag := range targetnode.Tags { var targetNodeTags = make(map[models.TagID]struct{})
if targetnode.Mutex != nil {
targetnode.Mutex.Lock()
targetNodeTags = maps.Clone(targetnode.Tags)
targetnode.Mutex.Unlock()
} else {
targetNodeTags = maps.Clone(targetnode.Tags)
}
targetNodeTags["*"] = struct{}{}
for nodeTag := range targetNodeTags {
for _, acl := range acls { for _, acl := range acls {
if !acl.Enabled { if !acl.Enabled {
continue continue

View file

@ -28,6 +28,9 @@ var (
func getAllExtClientsFromCache() (extClients []models.ExtClient) { func getAllExtClientsFromCache() (extClients []models.ExtClient) {
extClientCacheMutex.RLock() extClientCacheMutex.RLock()
for _, extclient := range extClientCacheMap { for _, extclient := range extClientCacheMap {
if extclient.Mutex == nil {
extclient.Mutex = &sync.Mutex{}
}
extClients = append(extClients, extclient) extClients = append(extClients, extclient)
} }
extClientCacheMutex.RUnlock() extClientCacheMutex.RUnlock()
@ -43,12 +46,18 @@ func deleteExtClientFromCache(key string) {
func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) { func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) {
extClientCacheMutex.RLock() extClientCacheMutex.RLock()
extclient, ok = extClientCacheMap[key] extclient, ok = extClientCacheMap[key]
if extclient.Mutex == nil {
extclient.Mutex = &sync.Mutex{}
}
extClientCacheMutex.RUnlock() extClientCacheMutex.RUnlock()
return return
} }
func storeExtClientInCache(key string, extclient models.ExtClient) { func storeExtClientInCache(key string, extclient models.ExtClient) {
extClientCacheMutex.Lock() extClientCacheMutex.Lock()
if extclient.Mutex == nil {
extclient.Mutex = &sync.Mutex{}
}
extClientCacheMap[key] = extclient extClientCacheMap[key] = extclient
extClientCacheMutex.Unlock() extClientCacheMutex.Unlock()
} }

View file

@ -35,12 +35,18 @@ var (
func getNodeFromCache(nodeID string) (node models.Node, ok bool) { func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
nodeCacheMutex.RLock() nodeCacheMutex.RLock()
node, ok = nodesCacheMap[nodeID] node, ok = nodesCacheMap[nodeID]
if node.Mutex == nil {
node.Mutex = &sync.Mutex{}
}
nodeCacheMutex.RUnlock() nodeCacheMutex.RUnlock()
return return
} }
func getNodesFromCache() (nodes []models.Node) { func getNodesFromCache() (nodes []models.Node) {
nodeCacheMutex.RLock() nodeCacheMutex.RLock()
for _, node := range nodesCacheMap { for _, node := range nodesCacheMap {
if node.Mutex == nil {
node.Mutex = &sync.Mutex{}
}
nodes = append(nodes, node) nodes = append(nodes, node)
} }
nodeCacheMutex.RUnlock() nodeCacheMutex.RUnlock()
@ -425,6 +431,9 @@ func GetAllNodes() ([]models.Node, error) {
} }
// add node to our array // add node to our array
nodes = append(nodes, node) nodes = append(nodes, node)
if node.Mutex == nil {
node.Mutex = &sync.Mutex{}
}
nodesMap[node.ID.String()] = node nodesMap[node.ID.String()] = node
} }
@ -811,9 +820,16 @@ func GetTagMapWithNodes() (tagNodesMap map[models.TagID][]models.Node) {
if nodeI.Tags == nil { if nodeI.Tags == nil {
continue continue
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Lock()
}
for nodeTagID := range nodeI.Tags { for nodeTagID := range nodeI.Tags {
tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI) tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Unlock()
}
} }
return return
} }
@ -825,9 +841,15 @@ func GetTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
if nodeI.Tags == nil { if nodeI.Tags == nil {
continue continue
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Lock()
}
for nodeTagID := range nodeI.Tags { for nodeTagID := range nodeI.Tags {
tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI) tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Unlock()
}
} }
tagNodesMap["*"] = nodes tagNodesMap["*"] = nodes
if !withStaticNodes { if !withStaticNodes {
@ -846,17 +868,16 @@ func AddTagMapWithStaticNodes(netID models.NetworkID,
if extclient.Tags == nil || extclient.RemoteAccessClientID != "" { if extclient.Tags == nil || extclient.RemoteAccessClientID != "" {
continue continue
} }
for tagID := range extclient.Tags { if extclient.Mutex != nil {
tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{ extclient.Mutex.Lock()
IsStatic: true, }
StaticNode: extclient, for tagID := range extclient.Tags {
}) tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode())
tagNodesMap["*"] = append(tagNodesMap["*"], models.Node{ tagNodesMap["*"] = append(tagNodesMap["*"], extclient.ConvertToStaticNode())
IsStatic: true, }
StaticNode: extclient, if extclient.Mutex != nil {
}) extclient.Mutex.Unlock()
} }
} }
return tagNodesMap return tagNodesMap
} }
@ -871,11 +892,14 @@ func AddTagMapWithStaticNodesWithUsers(netID models.NetworkID,
if extclient.Tags == nil { if extclient.Tags == nil {
continue continue
} }
if extclient.Mutex != nil {
extclient.Mutex.Lock()
}
for tagID := range extclient.Tags { for tagID := range extclient.Tags {
tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{ tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode())
IsStatic: true, }
StaticNode: extclient, if extclient.Mutex != nil {
}) extclient.Mutex.Unlock()
} }
} }
@ -893,9 +917,15 @@ func GetNodesWithTag(tagID models.TagID) map[string]models.Node {
if nodeI.Tags == nil { if nodeI.Tags == nil {
continue continue
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Lock()
}
if _, ok := nodeI.Tags[tagID]; ok { if _, ok := nodeI.Tags[tagID]; ok {
nMap[nodeI.ID.String()] = nodeI nMap[nodeI.ID.String()] = nodeI
} }
if nodeI.Mutex != nil {
nodeI.Mutex.Unlock()
}
} }
return AddStaticNodesWithTag(tag, nMap) return AddStaticNodesWithTag(tag, nMap)
} }
@ -909,13 +939,15 @@ func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[stri
if extclient.RemoteAccessClientID != "" { if extclient.RemoteAccessClientID != "" {
continue continue
} }
if _, ok := extclient.Tags[tag.ID]; ok { if extclient.Mutex != nil {
nMap[extclient.ClientID] = models.Node{ extclient.Mutex.Lock()
IsStatic: true, }
StaticNode: extclient, if _, ok := extclient.Tags[tag.ID]; ok {
} nMap[extclient.ClientID] = extclient.ConvertToStaticNode()
}
if extclient.Mutex != nil {
extclient.Mutex.Unlock()
} }
} }
return nMap return nMap
} }
@ -931,10 +963,7 @@ func GetStaticNodeWithTag(tagID models.TagID) map[string]models.Node {
return nMap return nMap
} }
for _, extclient := range extclients { for _, extclient := range extclients {
nMap[extclient.ClientID] = models.Node{ nMap[extclient.ClientID] = extclient.ConvertToStaticNode()
IsStatic: true,
StaticNode: extclient,
}
} }
return nMap return nMap
} }

View file

@ -1,5 +1,7 @@
package models package models
import "sync"
// ExtClient - struct for external clients // ExtClient - struct for external clients
type ExtClient struct { type ExtClient struct {
ClientID string `json:"clientid" bson:"clientid"` ClientID string `json:"clientid" bson:"clientid"`
@ -25,6 +27,7 @@ type ExtClient struct {
DeviceName string `json:"device_name"` DeviceName string `json:"device_name"`
PublicEndpoint string `json:"public_endpoint"` PublicEndpoint string `json:"public_endpoint"`
Country string `json:"country"` Country string `json:"country"`
Mutex *sync.Mutex `json:"-"`
} }
// CustomExtClient - struct for CustomExtClient params // CustomExtClient - struct for CustomExtClient params
@ -55,5 +58,6 @@ func (ext *ExtClient) ConvertToStaticNode() Node {
Tags: ext.Tags, Tags: ext.Tags,
IsStatic: true, IsStatic: true,
StaticNode: *ext, StaticNode: *ext,
Mutex: ext.Mutex,
} }
} }

View file

@ -5,6 +5,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -117,6 +118,7 @@ type Node struct {
IsUserNode bool `json:"is_user_node"` IsUserNode bool `json:"is_user_node"`
StaticNode ExtClient `json:"static_node"` StaticNode ExtClient `json:"static_node"`
Status NodeStatus `json:"node_status"` Status NodeStatus `json:"node_status"`
Mutex *sync.Mutex `json:"-"`
} }
// LegacyNode - legacy struct for node model // LegacyNode - legacy struct for node model