mirror of
https://github.com/gravitl/netmaker.git
synced 2025-11-10 00:30:37 +08:00
sort the acl rules
This commit is contained in:
parent
75307cb726
commit
c318c939f4
2 changed files with 68 additions and 1 deletions
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -950,8 +951,12 @@ func getUserAclRulesForNode(targetnode *models.Node,
|
|||
if aclRule, ok := rules[acl.ID]; ok {
|
||||
aclRule.IPList = append(aclRule.IPList, r.IPList...)
|
||||
aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
|
||||
aclRule.IPList = UniqueIPNetList(aclRule.IPList)
|
||||
aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
|
||||
rules[acl.ID] = aclRule
|
||||
} else {
|
||||
r.IPList = UniqueIPNetList(r.IPList)
|
||||
r.IP6List = UniqueIPNetList(r.IP6List)
|
||||
rules[acl.ID] = r
|
||||
}
|
||||
}
|
||||
|
|
@ -1115,9 +1120,52 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
|
|||
}
|
||||
}
|
||||
if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
|
||||
aclRule.IPList = UniqueIPNetList(aclRule.IPList)
|
||||
aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
|
||||
rules[acl.ID] = aclRule
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
// Compare two IPs and return true if ip1 < ip2
|
||||
func lessIP(ip1, ip2 net.IP) bool {
|
||||
ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
|
||||
ip2 = ip2.To16()
|
||||
return string(ip1) < string(ip2)
|
||||
}
|
||||
|
||||
// Sort by IP first, then by prefix length
|
||||
func sortIPNets(ipNets []net.IPNet) {
|
||||
sort.Slice(ipNets, func(i, j int) bool {
|
||||
ip1, ip2 := ipNets[i].IP, ipNets[j].IP
|
||||
mask1, _ := ipNets[i].Mask.Size()
|
||||
mask2, _ := ipNets[j].Mask.Size()
|
||||
|
||||
// Compare IPs first
|
||||
if ip1.Equal(ip2) {
|
||||
return mask1 < mask2 // If same IP, sort by subnet mask size
|
||||
}
|
||||
return lessIP(ip1, ip2)
|
||||
})
|
||||
}
|
||||
|
||||
func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
|
||||
uniqueMap := make(map[string]net.IPNet)
|
||||
|
||||
for _, ipnet := range ipnets {
|
||||
key := ipnet.String() // Uses CIDR notation as a unique key
|
||||
if _, exists := uniqueMap[key]; !exists {
|
||||
uniqueMap[key] = ipnet
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map back to slice
|
||||
uniqueList := make([]net.IPNet, 0, len(uniqueMap))
|
||||
for _, ipnet := range uniqueMap {
|
||||
uniqueList = append(uniqueList, ipnet)
|
||||
}
|
||||
sortIPNets(uniqueList)
|
||||
return uniqueList
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -463,7 +464,18 @@ func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.
|
|||
return newClient, nil
|
||||
}
|
||||
|
||||
// Sort a slice of net.IP addresses
|
||||
func sortIPs(ips []net.IP) {
|
||||
sort.Slice(ips, func(i, j int) bool {
|
||||
ip1, ip2 := ips[i].To16(), ips[j].To16()
|
||||
return string(ip1) < string(ip2) // Compare as byte slices
|
||||
})
|
||||
}
|
||||
|
||||
func GetStaticNodeIps(node models.Node) (ips []net.IP) {
|
||||
defer func() {
|
||||
sortIPs(ips)
|
||||
}()
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
|
||||
|
|
@ -487,7 +499,14 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
|
|||
|
||||
func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
|
||||
// fetch user access to static clients via policies
|
||||
|
||||
defer func() {
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
|
||||
return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
|
||||
}
|
||||
return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
|
||||
})
|
||||
}()
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
nodes, _ := GetNetworkNodes(node.Network)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue