sort the acl rules

This commit is contained in:
abhishek9686 2025-03-27 23:54:38 +04:00
parent 75307cb726
commit c318c939f4
2 changed files with 68 additions and 1 deletions

View file

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"maps"
"net"
"sort"
"sync"
"time"
@ -950,8 +951,12 @@ func getUserAclRulesForNode(targetnode *models.Node,
if aclRule, ok := rules[acl.ID]; ok {
aclRule.IPList = append(aclRule.IPList, r.IPList...)
aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
aclRule.IPList = UniqueIPNetList(aclRule.IPList)
aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
rules[acl.ID] = aclRule
} else {
r.IPList = UniqueIPNetList(r.IPList)
r.IP6List = UniqueIPNetList(r.IP6List)
rules[acl.ID] = r
}
}
@ -1115,9 +1120,52 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
}
}
if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
aclRule.IPList = UniqueIPNetList(aclRule.IPList)
aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
rules[acl.ID] = aclRule
}
}
}
return rules
}
// Compare two IPs and return true if ip1 < ip2
func lessIP(ip1, ip2 net.IP) bool {
ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
ip2 = ip2.To16()
return string(ip1) < string(ip2)
}
// Sort by IP first, then by prefix length
func sortIPNets(ipNets []net.IPNet) {
sort.Slice(ipNets, func(i, j int) bool {
ip1, ip2 := ipNets[i].IP, ipNets[j].IP
mask1, _ := ipNets[i].Mask.Size()
mask2, _ := ipNets[j].Mask.Size()
// Compare IPs first
if ip1.Equal(ip2) {
return mask1 < mask2 // If same IP, sort by subnet mask size
}
return lessIP(ip1, ip2)
})
}
func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
uniqueMap := make(map[string]net.IPNet)
for _, ipnet := range ipnets {
key := ipnet.String() // Uses CIDR notation as a unique key
if _, exists := uniqueMap[key]; !exists {
uniqueMap[key] = ipnet
}
}
// Convert map back to slice
uniqueList := make([]net.IPNet, 0, len(uniqueMap))
for _, ipnet := range uniqueMap {
uniqueList = append(uniqueList, ipnet)
}
sortIPNets(uniqueList)
return uniqueList
}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net"
"reflect"
"sort"
"strings"
"sync"
"time"
@ -463,7 +464,18 @@ func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.
return newClient, nil
}
// Sort a slice of net.IP addresses
func sortIPs(ips []net.IP) {
sort.Slice(ips, func(i, j int) bool {
ip1, ip2 := ips[i].To16(), ips[j].To16()
return string(ip1) < string(ip2) // Compare as byte slices
})
}
func GetStaticNodeIps(node models.Node) (ips []net.IP) {
defer func() {
sortIPs(ips)
}()
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@ -487,7 +499,14 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
// fetch user access to static clients via policies
defer func() {
sort.Slice(rules, func(i, j int) bool {
if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
}
return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
})
}()
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
nodes, _ := GetNetworkNodes(node.Network)