mirror of
				https://github.com/gravitl/netmaker.git
				synced 2025-11-04 11:39:22 +08:00 
			
		
		
		
	sort the acl rules
This commit is contained in:
		
							parent
							
								
									75307cb726
								
							
						
					
					
						commit
						c318c939f4
					
				
					 2 changed files with 68 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -5,6 +5,7 @@ import (
 | 
			
		|||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"maps"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
| 
						 | 
				
			
			@ -950,8 +951,12 @@ func getUserAclRulesForNode(targetnode *models.Node,
 | 
			
		|||
			if aclRule, ok := rules[acl.ID]; ok {
 | 
			
		||||
				aclRule.IPList = append(aclRule.IPList, r.IPList...)
 | 
			
		||||
				aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
 | 
			
		||||
				aclRule.IPList = UniqueIPNetList(aclRule.IPList)
 | 
			
		||||
				aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
 | 
			
		||||
				rules[acl.ID] = aclRule
 | 
			
		||||
			} else {
 | 
			
		||||
				r.IPList = UniqueIPNetList(r.IPList)
 | 
			
		||||
				r.IP6List = UniqueIPNetList(r.IP6List)
 | 
			
		||||
				rules[acl.ID] = r
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			@ -1115,9 +1120,52 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 | 
			
		|||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
 | 
			
		||||
				aclRule.IPList = UniqueIPNetList(aclRule.IPList)
 | 
			
		||||
				aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
 | 
			
		||||
				rules[acl.ID] = aclRule
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rules
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Compare two IPs and return true if ip1 < ip2
 | 
			
		||||
func lessIP(ip1, ip2 net.IP) bool {
 | 
			
		||||
	ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
 | 
			
		||||
	ip2 = ip2.To16()
 | 
			
		||||
	return string(ip1) < string(ip2)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sort by IP first, then by prefix length
 | 
			
		||||
func sortIPNets(ipNets []net.IPNet) {
 | 
			
		||||
	sort.Slice(ipNets, func(i, j int) bool {
 | 
			
		||||
		ip1, ip2 := ipNets[i].IP, ipNets[j].IP
 | 
			
		||||
		mask1, _ := ipNets[i].Mask.Size()
 | 
			
		||||
		mask2, _ := ipNets[j].Mask.Size()
 | 
			
		||||
 | 
			
		||||
		// Compare IPs first
 | 
			
		||||
		if ip1.Equal(ip2) {
 | 
			
		||||
			return mask1 < mask2 // If same IP, sort by subnet mask size
 | 
			
		||||
		}
 | 
			
		||||
		return lessIP(ip1, ip2)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
 | 
			
		||||
	uniqueMap := make(map[string]net.IPNet)
 | 
			
		||||
 | 
			
		||||
	for _, ipnet := range ipnets {
 | 
			
		||||
		key := ipnet.String() // Uses CIDR notation as a unique key
 | 
			
		||||
		if _, exists := uniqueMap[key]; !exists {
 | 
			
		||||
			uniqueMap[key] = ipnet
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Convert map back to slice
 | 
			
		||||
	uniqueList := make([]net.IPNet, 0, len(uniqueMap))
 | 
			
		||||
	for _, ipnet := range uniqueMap {
 | 
			
		||||
		uniqueList = append(uniqueList, ipnet)
 | 
			
		||||
	}
 | 
			
		||||
	sortIPNets(uniqueList)
 | 
			
		||||
	return uniqueList
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,6 +6,7 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
| 
						 | 
				
			
			@ -463,7 +464,18 @@ func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.
 | 
			
		|||
	return newClient, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sort a slice of net.IP addresses
 | 
			
		||||
func sortIPs(ips []net.IP) {
 | 
			
		||||
	sort.Slice(ips, func(i, j int) bool {
 | 
			
		||||
		ip1, ip2 := ips[i].To16(), ips[j].To16()
 | 
			
		||||
		return string(ip1) < string(ip2) // Compare as byte slices
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		sortIPs(ips)
 | 
			
		||||
	}()
 | 
			
		||||
	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 | 
			
		||||
	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -487,7 +499,14 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 | 
			
		|||
 | 
			
		||||
func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 | 
			
		||||
	// fetch user access to static clients via policies
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		sort.Slice(rules, func(i, j int) bool {
 | 
			
		||||
			if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
 | 
			
		||||
				return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
 | 
			
		||||
			}
 | 
			
		||||
			return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
 | 
			
		||||
		})
 | 
			
		||||
	}()
 | 
			
		||||
	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 | 
			
		||||
	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 | 
			
		||||
	nodes, _ := GetNetworkNodes(node.Network)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue