From f7258bf98fc5c4d65dc23d3370ca5d616f58fd77 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Wed, 2 Feb 2022 10:37:05 -0500 Subject: [PATCH] refactored some client leave & cache and server join logic --- controllers/node_grpc.go | 26 +++++++++++++++ logic/wireguard.go | 21 ++++++++++++ netclient/cli_options/cmds.go | 30 +++++++++--------- netclient/functions/common.go | 9 ++++-- netclient/functions/daemon.go | 60 +++++++++++++++++++---------------- 5 files changed, 101 insertions(+), 45 deletions(-) diff --git a/controllers/node_grpc.go b/controllers/node_grpc.go index 1969f9c1..b9863d9e 100644 --- a/controllers/node_grpc.go +++ b/controllers/node_grpc.go @@ -11,6 +11,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/serverctl" ) @@ -104,6 +105,31 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) runUpdates(&node, false) + go func(node *models.Node) { + if node.UDPHolePunch == "yes" { + var currentServerNodeID, getErr = logic.GetNetworkServerNodeID(node.Network) + if getErr != nil { + return + } + var currentServerNode, currErr = logic.GetNodeByID(currentServerNodeID) + if currErr != nil { + return + } + for i := 0; i < 5; i++ { + if logic.HasPeerConnected(node) { + if logic.ShouldPublishPeerPorts(¤tServerNode) { + err = mq.PublishPeerUpdate(¤tServerNode) + if err != nil { + logger.Log(1, "error publishing port updates when node", node.Name, "joined") + } + break + } + } + time.Sleep(time.Second << 1) // allow time for client to startup + } + } + }(&node) + return response, nil } diff --git a/logic/wireguard.go b/logic/wireguard.go index aac1a590..97a4e99b 100644 --- a/logic/wireguard.go +++ b/logic/wireguard.go @@ -25,6 +25,27 @@ func RemoveConf(iface string, printlog bool) error { return err } +// HasPeerConnected - checks if a client node has connected over WG +func HasPeerConnected(node *models.Node) bool { + client, err := wgctrl.New() + if err != nil { + return false + } + defer client.Close() + device, err := client.Device(node.Interface) + if err != nil { + return false + } + for _, peer := range device.Peers { + if peer.PublicKey.String() == node.PublicKey { + if peer.Endpoint != nil { + return true + } + } + } + return false +} + // == Private Functions == // gets the server peers locally diff --git a/netclient/cli_options/cmds.go b/netclient/cli_options/cmds.go index 5f90c1a8..4223a60c 100644 --- a/netclient/cli_options/cmds.go +++ b/netclient/cli_options/cmds.go @@ -62,21 +62,21 @@ func GetCommands(cliFlags []cli.Flag) []*cli.Command { return err }, }, - { - Name: "push", - Usage: "Push configuration changes to server.", - Flags: cliFlags, - // the action, or code that will be executed when - // we execute our `ns` command - Action: func(c *cli.Context) error { - cfg, _, err := config.GetCLIConfig(c) - if err != nil { - return err - } - err = command.Push(cfg) - return err - }, - }, + // { + // Name: "push", + // Usage: "Push configuration changes to server.", + // Flags: cliFlags, + // // the action, or code that will be executed when + // // we execute our `ns` command + // Action: func(c *cli.Context) error { + // cfg, _, err := config.GetCLIConfig(c) + // if err != nil { + // return err + // } + // err = command.Push(cfg) + // return err + // }, + // }, { Name: "pull", Usage: "Pull latest configuration and peers from server.", diff --git a/netclient/functions/common.go b/netclient/functions/common.go index 6dbdb2f9..8f4f1beb 100644 --- a/netclient/functions/common.go +++ b/netclient/functions/common.go @@ -185,7 +185,7 @@ func LeaveNetwork(network string) error { } } } - //extra network route setting required for freebsd and windows + // extra network route setting required for freebsd and windows, TODO mac?? if ncutils.IsWindows() { ip, mask, err := ncutils.GetNetworkIPMask(node.NetworkSettings.AddressRange) if err != nil { @@ -197,7 +197,12 @@ func LeaveNetwork(network string) error { } else if ncutils.IsLinux() { _, _ = ncutils.RunCmd("ip -4 route del "+node.NetworkSettings.AddressRange+" dev "+node.Interface, false) } - return RemoveLocalInstance(cfg, network) + + currentNets, err := ncutils.GetSystemNetworks() + if err != nil || len(currentNets) <= 1 { + return RemoveLocalInstance(cfg, network) + } + return daemon.Restart() } // RemoveLocalInstance - remove all netclient files locally for a network diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index 0282d8cb..6323f7a3 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -30,17 +30,31 @@ var messageCache = new(sync.Map) const lastNodeUpdate = "lnu" const lastPeerUpdate = "lpu" +type cachedMessage struct { + Message string + LastSeen time.Time +} + func insert(network, which, cache string) { - // var mu sync.Mutex - // mu.Lock() - // defer mu.Unlock() - messageCache.Store(fmt.Sprintf("%s%s", network, which), cache) + var newMessage = cachedMessage{ + Message: cache, + LastSeen: time.Now(), + } + ncutils.Log("storing new message: " + cache) + messageCache.Store(fmt.Sprintf("%s%s", network, which), newMessage) } func read(network, which string) string { val, isok := messageCache.Load(fmt.Sprintf("%s%s", network, which)) if isok { - return fmt.Sprintf("%v", val) + var readMessage = val.(cachedMessage) // fetch current cached message + if time.Now().After(readMessage.LastSeen.Add(time.Minute)) { // check if message has been there over a minute + messageCache.Delete(fmt.Sprintf("%s%s", network, which)) // remove old message if expired + ncutils.Log("cached message expired") + return "" + } + ncutils.Log("cache hit, skipping probably " + readMessage.Message) + return readMessage.Message // return current message if not expired } return "" } @@ -219,6 +233,7 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) { newNode.OS = runtime.GOOS // check if interface needs to delta ifaceDelta := ncutils.IfaceDelta(&cfg.Node, &newNode) + shouldDNSChange := cfg.Node.DNSOn != newNode.DNSOn cfg.Node = newNode switch newNode.Action { @@ -265,24 +280,15 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) { ncutils.Log("error resubscribing after interface change " + err.Error()) return } - } - /* - else { - ncutils.Log("syncing conf to " + file) - err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file) - if err != nil { - ncutils.Log("error syncing wg after peer update " + err.Error()) - return + if newNode.DNSOn == "yes" { + ncutils.Log("setting up DNS") + if err = local.UpdateDNS(cfg.Node.Interface, cfg.Network, cfg.Server.CoreDNSAddr); err != nil { + ncutils.Log("error applying dns" + err.Error()) } } - */ + } //deal with DNS - if newNode.DNSOn == "yes" { - ncutils.Log("setting up DNS") - if err = local.UpdateDNS(cfg.Node.Interface, cfg.Network, cfg.Server.CoreDNSAddr); err != nil { - ncutils.Log("error applying dns" + err.Error()) - } - } else { + if newNode.DNSOn != "yes" && shouldDNSChange { ncutils.Log("settng DNS off") _, err := ncutils.RunCmd("/usr/bin/resolvectl revert "+cfg.Node.Interface, true) if err != nil { @@ -311,14 +317,12 @@ func UpdatePeers(client mqtt.Client, msg mqtt.Message) { return } // see if cache hit, if so skip - /* - var currentMessage = read(peerUpdate.Network, lastPeerUpdate) - if currentMessage == string(data) { - return - } - */ + var currentMessage = read(peerUpdate.Network, lastPeerUpdate) + if currentMessage == string(data) { + ncutils.Log("cache hit") + return + } insert(peerUpdate.Network, lastPeerUpdate, string(data)) - ncutils.Log("update peer handler") file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf" err = wireguard.UpdateWgPeers(file, peerUpdate.Peers) @@ -326,13 +330,13 @@ func UpdatePeers(client mqtt.Client, msg mqtt.Message) { ncutils.Log("error updating wireguard peers" + err.Error()) return } - ncutils.Log("syncing conf to " + file) //err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file) err = wireguard.SetPeers(cfg.Node.Interface, cfg.Node.PersistentKeepalive, peerUpdate.Peers) if err != nil { ncutils.Log("error syncing wg after peer update " + err.Error()) return } + ncutils.Log(fmt.Sprintf("received peer update on network, %s", cfg.Network)) }() }