From d3a6faa90d3c7a485f8c49d587aab9cf1c270f6e Mon Sep 17 00:00:00 2001 From: afeiszli Date: Wed, 15 Jun 2022 11:18:49 -0400 Subject: [PATCH] updating route setting logic --- logic/peers.go | 24 ++++++------------ netclient/local/routes.go | 39 +++++++++-------------------- netclient/ncutils/netclientutils.go | 26 +++++++++++++++++++ netclient/wireguard/common.go | 7 ++++-- 4 files changed, 51 insertions(+), 45 deletions(-) diff --git a/logic/peers.go b/logic/peers.go index 758f3db4..bded64b4 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -377,33 +377,25 @@ func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) { allowedips = append(allowedips, peer.AllowedIPs...) } //delete any ips not permitted by acl - for i, ip := range allowedips { - target, err := findNode(ip.IP.String()) + for i := len(allowedips) - 1; i >= 0; i-- { + target, err := findNode(allowedips[i].IP.String()) if err != nil { - logger.Log(0, "failed to find node for ip", ip.IP.String(), err.Error()) + logger.Log(0, "failed to find node for ip", allowedips[i].IP.String(), err.Error()) continue } if target == nil { - logger.Log(0, "failed to find node for ip", ip.IP.String()) + logger.Log(0, "failed to find node for ip", allowedips[i].IP.String()) continue } if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID), nodeacls.NodeID(target.ID)) { logger.Log(0, "deleting node from relayednode per acl", node.Name, target.Name) - if i+1 == len(allowedips) { - allowedips = allowedips[:i] - } else { - allowedips = append(allowedips[:i], allowedips[i+1:]...) - } + allowedips = append(allowedips[:i], allowedips[i+1:]...) } } //delete self from allowed ips - for i, ip := range allowedips { - if ip.IP.String() == node.Address || ip.IP.String() == node.Address6 { - if i+1 == len(allowedips) { - allowedips = allowedips[:i] - } else { - allowedips = append(allowedips[:i], allowedips[i+1:]...) - } + for i := len(allowedips) - 1; i >= 0; i-- { + if allowedips[i].IP.String() == node.Address || allowedips[i].IP.String() == node.Address6 { + allowedips = append(allowedips[:i], allowedips[i+1:]...) } } diff --git a/netclient/local/routes.go b/netclient/local/routes.go index 44008744..a9721298 100644 --- a/netclient/local/routes.go +++ b/netclient/local/routes.go @@ -11,41 +11,26 @@ import ( // TODO handle ipv6 in future // SetPeerRoutes - sets/removes ip routes for each peer on a network -func SetPeerRoutes(iface string, oldPeers map[string][]net.IPNet, newPeers []wgtypes.PeerConfig) { +func SetPeerRoutes(iface string, oldPeers map[string]bool, newPeers []wgtypes.PeerConfig) { // traverse through all recieved peers for _, peer := range newPeers { - // if pubkey found in existing peers, check against existing peer - currPeerAllowedIPs := oldPeers[peer.PublicKey.String()] - if currPeerAllowedIPs != nil { - // traverse IPs, check to see if old peer contains each IP - for _, allowedIP := range peer.AllowedIPs { // compare new ones (if any) to old ones - if !ncutils.IPNetSliceContains(currPeerAllowedIPs, allowedIP) { - if err := setRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil { - logger.Log(1, err.Error()) - } - } - } - for _, allowedIP := range currPeerAllowedIPs { // compare old ones (if any) to new ones - if !ncutils.IPNetSliceContains(peer.AllowedIPs, allowedIP) { - if err := deleteRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil { - logger.Log(1, err.Error()) - } - } - } - delete(oldPeers, peer.PublicKey.String()) // remove peer as it was found and processed - } else { - for _, allowedIP := range peer.AllowedIPs { // add all routes as peer doesn't exist - if err := setRoute(iface, &allowedIP, allowedIP.String()); err != nil { + for _, allowedIP := range peer.AllowedIPs { + if !oldPeers[allowedIP.String()] { + if err := setRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil { logger.Log(1, err.Error()) } + } else { + delete(oldPeers, allowedIP.String()) } } } - // traverse through all remaining existing peers - for _, allowedIPs := range oldPeers { - for _, allowedIP := range allowedIPs { - deleteRoute(iface, &allowedIP, allowedIP.IP.String()) + for i, _ := range oldPeers { + ip, err := ncutils.GetIPNetFromString(i) + if err != nil { + logger.Log(1, err.Error()) + } else { + deleteRoute(iface, &ip, ip.IP.String()) } } } diff --git a/netclient/ncutils/netclientutils.go b/netclient/ncutils/netclientutils.go index 30cb8b82..317fcac6 100644 --- a/netclient/ncutils/netclientutils.go +++ b/netclient/ncutils/netclientutils.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/c-robinson/iplib" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -593,3 +594,28 @@ func MakeRandomString(n int) string { } return string(result) } + +func GetIPNetFromString(ip string) (net.IPNet, error) { + var ipnet *net.IPNet + var err error + // parsing as a CIDR first. If valid CIDR, append + if _, cidr, err := net.ParseCIDR(ip); err == nil { + ipnet = cidr + } else { // parsing as an IP second. If valid IP, check if ipv4 or ipv6, then append + if iplib.Version(net.ParseIP(ip)) == 4 { + ipnet = &net.IPNet{ + IP: net.ParseIP(ip), + Mask: net.CIDRMask(32, 32), + } + } else if iplib.Version(net.ParseIP(ip)) == 6 { + ipnet = &net.IPNet{ + IP: net.ParseIP(ip), + Mask: net.CIDRMask(128, 128), + } + } + } + if ipnet == nil { + err = errors.New(ip + " is not a valid ip or cidr") + } + return *ipnet, err +} diff --git a/netclient/wireguard/common.go b/netclient/wireguard/common.go index fe1508e9..b4dc8aab 100644 --- a/netclient/wireguard/common.go +++ b/netclient/wireguard/common.go @@ -28,7 +28,8 @@ const ( func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error { var devicePeers []wgtypes.Peer var keepalive = node.PersistentKeepalive - var oldPeerAllowedIps = make(map[string][]net.IPNet, len(peers)) + var oldPeerAllowedIps = make(map[string]bool, len(peers)) + var err error devicePeers, err = GetDevicePeers(iface) if err != nil { @@ -106,7 +107,9 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error log.Println(output, "error removing peer", currentPeer.PublicKey.String()) } } - oldPeerAllowedIps[currentPeer.PublicKey.String()] = currentPeer.AllowedIPs + for _, ip := range currentPeer.AllowedIPs { + oldPeerAllowedIps[ip.String()] = true + } } } }