diff --git a/mq/mq.go b/mq/mq.go index 6cd80dcc..f1f5553d 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -3,6 +3,7 @@ package mq import ( "encoding/json" "errors" + "log" "net" "strconv" "strings" @@ -13,6 +14,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -42,8 +44,8 @@ var Ping mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { return } node.SetLastCheckIn() - if err := logic.UpdateNode(&node, &node) ; err != nil { - logger.Log(0, "error updating node "+ err.Error()) + if err := logic.UpdateNode(&node, &node); err != nil { + logger.Log(0, "error updating node "+err.Error()) } logger.Log(0, "ping processed") // --TODO --set client version once feature is implemented. @@ -66,8 +68,8 @@ var PublicKeyUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Mess } node.PublicKey = key node.SetLastCheckIn() - if err := logic.UpdateNode(&node, &node) ; err != nil { - logger.Log(0, "error updating node "+ err.Error()) + if err := logic.UpdateNode(&node, &node); err != nil { + logger.Log(0, "error updating node "+err.Error()) } if err := UpdatePeers(client, node); err != nil { logger.Log(0, "error updating peers "+err.Error()) @@ -92,8 +94,8 @@ var IPUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { } node.Endpoint = ip node.SetLastCheckIn() - if err := logic.UpdateNode(&node, &node) ; err != nil { - logger.Log(0, "error updating node "+ err.Error()) + if err := logic.UpdateNode(&node, &node); err != nil { + logger.Log(0, "error updating node "+err.Error()) } if err != UpdatePeers(client, node) { logger.Log(0, "error updating peers "+err.Error()) @@ -106,50 +108,141 @@ func UpdatePeers(client mqtt.Client, newnode models.Node) error { if err != nil { return err } - keepalive, _ := time.ParseDuration(string(newnode.PersistentKeepalive)+"s") - for _, node := range networkNodes { - var peers []wgtypes.PeerConfig + dualstack := false + keepalive, _ := time.ParseDuration(string(newnode.PersistentKeepalive) + "s") + defaultkeepalive, _ := time.ParseDuration("25s") + for _, node := range networkNodes { + var peers []wgtypes.PeerConfig var peerUpdate models.PeerUpdate - for _, peer := range networkNodes{ - if peer.ID == node.ID { - //skip - continue - } - pubkey, err := wgtypes.ParseKey(peer.PublicKey) - if err != nil { + var gateways []string + + for _, peer := range networkNodes { + if peer.ID == node.ID { + //skip + continue + } + var allowedips []net.IPNet + var peeraddr = net.IPNet{ + IP: net.ParseIP(peer.Address), + Mask: net.CIDRMask(32, 32), + } + //hasGateway := false + pubkey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { return err - } - if node.Endpoint == peer.Endpoint { - if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" { - peer.Endpoint = peer.LocalAddress - }else { - continue - } - } - endpoint := peer.Endpoint + ":" + strconv.Itoa(int(peer.ListenPort)) - //fmt.Println("endpoint: ", endpoint, peer.Endpoint, peer.ListenPort) - address, err := net.ResolveUDPAddr("udp", endpoint) - if err != nil { + } + if node.Endpoint == peer.Endpoint { + if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" { + peer.Endpoint = peer.LocalAddress + } else { + continue + } + } + endpoint := peer.Endpoint + ":" + strconv.Itoa(int(peer.ListenPort)) + //fmt.Println("endpoint: ", endpoint, peer.Endpoint, peer.ListenPort) + address, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { return err - } - //calculate Allowed IPs. - var peerData wgtypes.PeerConfig - peerData = wgtypes.PeerConfig{ - PublicKey: pubkey, - Endpoint: address, - PersistentKeepaliveInterval: &keepalive, - //AllowedIPs: allowedIPs - } - peers = append (peers, peerData) - } + } + //calculate Allowed IPs. + allowedips = append(allowedips, peeraddr) + // handle manually set peers + for _, allowedIp := range node.AllowedIPs { + if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil { + nodeEndpointArr := strings.Split(node.Endpoint, ":") + if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists.. + allowedips = append(allowedips, *ipnet) + } + } else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address { + ipnet := net.IPNet{ + IP: net.ParseIP(allowedIp), + Mask: net.CIDRMask(32, 32), + } + allowedips = append(allowedips, ipnet) + } + } + // handle egress gateway peers + if node.IsEgressGateway == "yes" { + //hasGateway = true + ranges := node.EgressGatewayRanges + for _, iprange := range ranges { // go through each cidr for egress gateway + _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr + if err != nil { + ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1) + continue // if can't parse CIDR + } + nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node + if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node + ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2) + continue // skip adding egress range if overlaps with node's ip + } + if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node + ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2) + continue // skip adding egress range if overlaps with node's local ip + } + gateways = append(gateways, iprange) + if err != nil { + log.Println("ERROR ENCOUNTERED SETTING GATEWAY") + } else { + allowedips = append(allowedips, *ipnet) + } + } + } + var peerData wgtypes.PeerConfig + if node.Address6 != "" && dualstack { + var addr6 = net.IPNet{ + IP: net.ParseIP(node.Address6), + Mask: net.CIDRMask(128, 128), + } + allowedips = append(allowedips, addr6) + } + if node.IsServer == "yes" && !(node.IsServer == "yes") { + peerData = wgtypes.PeerConfig{ + PublicKey: pubkey, + PersistentKeepaliveInterval: &defaultkeepalive, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } else if keepalive != 0 { + peerData = wgtypes.PeerConfig{ + PublicKey: pubkey, + PersistentKeepaliveInterval: &defaultkeepalive, + //Endpoint: &net.UDPAddr{ + // IP: net.ParseIP(node.Endpoint), + // Port: int(node.ListenPort), + //}, + Endpoint: address, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } else { + peerData = wgtypes.PeerConfig{ + PublicKey: pubkey, + //Endpoint: &net.UDPAddr{ + // IP: net.ParseIP(node.Endpoint), + // Port: int(node.ListenPort), + //}, + Endpoint: address, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } + //peerData = wgtypes.PeerConfig{ + // PublicKey: pubkey, + // Endpoint: address, + // PersistentKeepaliveInterval: &keepalive, + //AllowedIPs: allowedIPs + //} + peers = append(peers, peerData) + } peerUpdate.Network = node.Network - peerUpdate.Peers = peers + peerUpdate.Peers = peers data, err := json.Marshal(&peerUpdate) if err != nil { logger.Log(0, "error marshaling peer update "+err.Error()) return err } - if token := client.Publish("/update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil { + if token := client.Publish("/update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil { logger.Log(0, "error sending peer updatte to no") return err } @@ -198,7 +291,7 @@ func NewPeer(node models.Node) error { if token := client.Connect(); token.Wait() && token.Error() != nil { return token.Error() } - + if err := UpdatePeers(client, node); err != nil { return err }