Merge pull request #3402 from gravitl/rules_list_fix_v0.90

Rules list fix v0.90
This commit is contained in:
Abhishek K 2025-03-29 02:08:32 +04:00 committed by GitHub
commit e80a277e9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 66 additions and 24 deletions

View file

@ -1321,8 +1321,12 @@ func getUserAclRulesForNode(targetnode *models.Node,
if aclRule, ok := rules[acl.ID]; ok {
aclRule.IPList = append(aclRule.IPList, r.IPList...)
aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
aclRule.IPList = UniqueIPNetList(aclRule.IPList)
aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
rules[acl.ID] = aclRule
} else {
r.IPList = UniqueIPNetList(r.IPList)
r.IP6List = UniqueIPNetList(r.IP6List)
rules[acl.ID] = r
}
}
@ -1598,24 +1602,6 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu
}
return rules
}
func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
uniqueMap := make(map[string]net.IPNet)
for _, ipnet := range ipnets {
key := ipnet.String() // Uses CIDR notation as a unique key
if _, exists := uniqueMap[key]; !exists {
uniqueMap[key] = ipnet
}
}
// Convert map back to slice
uniqueList := make([]net.IPNet, 0, len(uniqueMap))
for _, ipnet := range uniqueMap {
uniqueList = append(uniqueList, ipnet)
}
return uniqueList
}
func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclRule) {
rules = make(map[string]models.AclRule)
@ -1831,3 +1817,44 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
}
return
}
// Compare two IPs and return true if ip1 < ip2
func lessIP(ip1, ip2 net.IP) bool {
ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
ip2 = ip2.To16()
return string(ip1) < string(ip2)
}
// Sort by IP first, then by prefix length
func sortIPNets(ipNets []net.IPNet) {
sort.Slice(ipNets, func(i, j int) bool {
ip1, ip2 := ipNets[i].IP, ipNets[j].IP
mask1, _ := ipNets[i].Mask.Size()
mask2, _ := ipNets[j].Mask.Size()
// Compare IPs first
if ip1.Equal(ip2) {
return mask1 < mask2 // If same IP, sort by subnet mask size
}
return lessIP(ip1, ip2)
})
}
func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
uniqueMap := make(map[string]net.IPNet)
for _, ipnet := range ipnets {
key := ipnet.String() // Uses CIDR notation as a unique key
if _, exists := uniqueMap[key]; !exists {
uniqueMap[key] = ipnet
}
}
// Convert map back to slice
uniqueList := make([]net.IPNet, 0, len(uniqueMap))
for _, ipnet := range uniqueMap {
uniqueList = append(uniqueList, ipnet)
}
sortIPNets(uniqueList)
return uniqueList
}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net"
"reflect"
"sort"
"strings"
"sync"
"time"
@ -464,7 +465,18 @@ func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.
return newClient, nil
}
// Sort a slice of net.IP addresses
func sortIPs(ips []net.IP) {
sort.Slice(ips, func(i, j int) bool {
ip1, ip2 := ips[i].To16(), ips[j].To16()
return string(ip1) < string(ip2) // Compare as byte slices
})
}
func GetStaticNodeIps(node models.Node) (ips []net.IP) {
defer func() {
sortIPs(ips)
}()
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@ -731,7 +743,14 @@ func getFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []
func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
// fetch user access to static clients via policies
defer func() {
sort.Slice(rules, func(i, j int) bool {
if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
}
return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
})
}()
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
nodes, _ := GetNetworkNodes(node.Network)
nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)

View file

@ -706,6 +706,7 @@ func createNode(node *models.Node) error {
if err != nil {
return err
}
node.SetLastCheckIn()
err = database.Insert(node.ID.String(), string(nodebytes), database.NODES_TABLE_NAME)
if err != nil {
return err

View file

@ -126,11 +126,6 @@ func initialize() { // Client Mode Prereq Check
}
}
if servercfg.IsMessageQueueBackend() {
if err = mq.ServerStartNotify(); err != nil {
logger.Log(0, "error occurred when notifying nodes of startup", err.Error())
}
}
}
func startControllers(wg *sync.WaitGroup, ctx context.Context) {