From 68b52279ae1c2cfa1610274f42a304b9f7200e2a Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Mon, 25 Apr 2022 16:30:18 -0400 Subject: [PATCH] added peers to pull/join responses --- controllers/node.go | 27 ++++++++++++++++-- models/structs.go | 7 +++++ netclient/functions/join.go | 7 +++-- netclient/functions/pull.go | 15 ++++------ netclient/wireguard/common.go | 53 +++++++++++++++++++++-------------- netclient/wireguard/unix.go | 41 --------------------------- 6 files changed, 73 insertions(+), 77 deletions(-) diff --git a/controllers/node.go b/controllers/node.go index f6dc48a7..249e3ae0 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -382,9 +382,21 @@ func getNode(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + + peerUpdate, err := logic.GetPeerUpdate(&node) + if err != nil && !database.IsEmptyRecord(err) { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + response := models.NodeGet{ + Node: node, + Peers: peerUpdate.Peers, + } + logger.Log(2, r.Header.Get("user"), "fetched node", params["nodeid"]) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(node) + json.NewEncoder(w).Encode(response) } //Get the time that a network of nodes was last modified. @@ -490,9 +502,20 @@ func createNode(w http.ResponseWriter, r *http.Request) { return } + peerUpdate, err := logic.GetPeerUpdate(&node) + if err != nil && !database.IsEmptyRecord(err) { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + response := models.NodeGet{ + Node: node, + Peers: peerUpdate.Peers, + } + logger.Log(1, r.Header.Get("user"), "created new node", node.Name, "on network", node.Network) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(node) + json.NewEncoder(w).Encode(response) runForceServerUpdate(&node) } diff --git a/models/structs.go b/models/structs.go index a827ccbe..5ddb6133 100644 --- a/models/structs.go +++ b/models/structs.go @@ -2,6 +2,7 @@ package models import ( jwt "github.com/golang-jwt/jwt/v4" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) const PLACEHOLDER_KEY_TEXT = "ACCESS_KEY" @@ -186,3 +187,9 @@ type TrafficKeys struct { Mine []byte `json:"mine" bson:"mine" yaml:"mine"` Server []byte `json:"server" bson:"server" yaml:"server"` } + +// NodeGet - struct for a single node get response +type NodeGet struct { + Node Node `json:"node" bson:"node" yaml:"node"` + Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` +} diff --git a/netclient/functions/join.go b/netclient/functions/join.go index 251e6ba1..4c7bf7d3 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -138,12 +138,13 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string) error { bodybytes, _ := io.ReadAll(response.Body) return fmt.Errorf("error creating node %s %s", response.Status, string(bodybytes)) } - var node models.Node - if err := json.NewDecoder(response.Body).Decode(&node); err != nil { + var nodeGET models.NodeGet + if err := json.NewDecoder(response.Body).Decode(&nodeGET); err != nil { //not sure the next line will work as response.Body probably needs to be reset before it can be read again bodybytes, _ := ioutil.ReadAll(response.Body) return fmt.Errorf("error decoding node from server %w %s", err, string(bodybytes)) } + node := nodeGET.Node // safety check. If returned node from server is local, but not currently configured as local, set to local addr if cfg.Node.IsLocal != "yes" && node.IsLocal == "yes" && node.LocalRange != "" { node.LocalAddress, err = ncutils.GetLocalIP(node.LocalRange) @@ -182,7 +183,7 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string) error { logger.Log(0, "failed to make backup, node will not auto restore if config is corrupted") } logger.Log(0, "starting wireguard") - err = wireguard.InitWireguard(&node, privateKey, []wgtypes.PeerConfig{}, false) + err = wireguard.InitWireguard(&node, privateKey, nodeGET.Peers[:], false) if err != nil { return err } diff --git a/netclient/functions/pull.go b/netclient/functions/pull.go index 8593d79f..79a7abb0 100644 --- a/netclient/functions/pull.go +++ b/netclient/functions/pull.go @@ -46,27 +46,22 @@ func Pull(network string, iface bool) (*models.Node, error) { return nil, (fmt.Errorf("%s %w", string(bytes), err)) } defer response.Body.Close() - resNode := models.Node{} - if err := json.NewDecoder(response.Body).Decode(&resNode); err != nil { + var nodeGET models.NodeGet + if err := json.NewDecoder(response.Body).Decode(&nodeGET); err != nil { return nil, fmt.Errorf("error decoding node %w", err) } + resNode := nodeGET.Node // ensure that the OS never changes resNode.OS = runtime.GOOS if iface { - // check for interface change - if cfg.Node.Interface != resNode.Interface { - if err = DeleteInterface(cfg.Node.Interface, cfg.Node.PostDown); err != nil { - logger.Log(1, "could not delete old interface ", cfg.Node.Interface) - } - } if err = config.ModConfig(&resNode); err != nil { return nil, err } - if err = wireguard.SetWGConfig(network, false); err != nil { + if err = wireguard.SetWGConfig(network, false, nodeGET.Peers[:]); err != nil { return nil, err } } else { - if err = wireguard.SetWGConfig(network, true); err != nil { + if err = wireguard.SetWGConfig(network, true, nodeGET.Peers[:]); err != nil { if errors.Is(err, os.ErrNotExist) && !ncutils.IsFreeBSD() { return Pull(network, true) } else { diff --git a/netclient/wireguard/common.go b/netclient/wireguard/common.go index 26ac629a..f1e0c6e4 100644 --- a/netclient/wireguard/common.go +++ b/netclient/wireguard/common.go @@ -30,24 +30,11 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error var keepalive = node.PersistentKeepalive var oldPeerAllowedIps = make(map[string][]net.IPNet, len(peers)) var err error - if ncutils.IsFreeBSD() { - if devicePeers, err = ncutils.GetPeers(iface); err != nil { - return err - } - } else { - client, err := wgctrl.New() - if err != nil { - logger.Log(0, "failed to start wgctrl") - return err - } - defer client.Close() - device, err := client.Device(iface) - if err != nil { - logger.Log(0, "failed to parse interface") - return err - } - devicePeers = device.Peers + devicePeers, err = GetDevicePeers(iface) + if err != nil { + return err } + if len(devicePeers) > 1 && len(peers) == 0 { logger.Log(1, "no peers pulled") return err @@ -235,7 +222,7 @@ func InitWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig } // SetWGConfig - sets the WireGuard Config of a given network and checks if it needs a peer update -func SetWGConfig(network string, peerupdate bool) error { +func SetWGConfig(network string, peerupdate bool, peers []wgtypes.PeerConfig) error { cfg, err := config.ReadConfig(network) if err != nil { @@ -257,11 +244,11 @@ func SetWGConfig(network string, peerupdate bool) error { return err } } - err = SetPeers(iface, &nodecfg, []wgtypes.PeerConfig{}) + err = SetPeers(iface, &nodecfg, peers) } else if peerupdate { - err = InitWireguard(&nodecfg, privkey, []wgtypes.PeerConfig{}, true) + err = InitWireguard(&nodecfg, privkey, peers, true) } else { - err = InitWireguard(&nodecfg, privkey, []wgtypes.PeerConfig{}, false) + err = InitWireguard(&nodecfg, privkey, peers, false) } if nodecfg.DNSOn == "yes" { _ = local.UpdateDNS(nodecfg.Interface, nodecfg.Network, servercfg.CoreDNSAddr) @@ -527,3 +514,27 @@ func RemoveConfGraceful(ifacename string) { } time.Sleep(time.Second << 1) } + +// GetDevicePeers - gets the current device's peers +func GetDevicePeers(iface string) ([]wgtypes.Peer, error) { + if ncutils.IsFreeBSD() { + if devicePeers, err := ncutils.GetPeers(iface); err != nil { + return nil, err + } else { + return devicePeers, nil + } + } else { + client, err := wgctrl.New() + if err != nil { + logger.Log(0, "failed to start wgctrl") + return nil, err + } + defer client.Close() + device, err := client.Device(iface) + if err != nil { + logger.Log(0, "failed to parse interface") + return nil, err + } + return device.Peers, nil + } +} diff --git a/netclient/wireguard/unix.go b/netclient/wireguard/unix.go index ddedf0e3..d0167e57 100644 --- a/netclient/wireguard/unix.go +++ b/netclient/wireguard/unix.go @@ -8,50 +8,9 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/ncutils" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// SetWGKeyConfig - sets the wg conf with a new private key -func SetWGKeyConfig(network string, serveraddr string) error { - - cfg, err := config.ReadConfig(network) - if err != nil { - return err - } - - node := cfg.Node - - privatekey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return err - } - privkeystring := privatekey.String() - publickey := privatekey.PublicKey() - - node.PublicKey = publickey.String() - - err = StorePrivKey(privkeystring, network) - if err != nil { - return err - } - if node.Action == models.NODE_UPDATE_KEY { - node.Action = models.NODE_NOOP - } - err = config.ModConfig(&node) - if err != nil { - return err - } - - err = SetWGConfig(network, false) - if err != nil { - return err - } - - return err -} - // ApplyWGQuickConf - applies wg-quick commands if os supports func ApplyWGQuickConf(confPath string, ifacename string) error { if ncutils.IsWindows() {