From d6554ef081c4fed41f06064d99c83929f22d7f73 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Tue, 25 Jan 2022 22:14:31 -0500 Subject: [PATCH 1/2] initial commit --- controllers/node_grpc.go | 5 ++ logic/nodes.go | 15 ++++++ logic/peers.go | 5 ++ models/mqtt.go | 11 ++-- models/network.go | 9 ++-- netclient/functions/daemon.go | 78 ++++++++++++++++++++++++++++- netclient/ncutils/netclientutils.go | 10 ++++ 7 files changed, 123 insertions(+), 10 deletions(-) diff --git a/controllers/node_grpc.go b/controllers/node_grpc.go index c2fe6f19..b0130b79 100644 --- a/controllers/node_grpc.go +++ b/controllers/node_grpc.go @@ -67,6 +67,11 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) } } + var serverNodes = logic.GetServerNodes(node.Network) + for _, server := range serverNodes { + node.NetworkSettings.DefaultServerAddrs = append(node.NetworkSettings.DefaultServerAddrs, server.Address) + } + err = logic.CreateNode(&node) if err != nil { return nil, err diff --git a/logic/nodes.go b/logic/nodes.go index 48008db3..f6f74c4d 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -58,6 +58,21 @@ func GetSortedNetworkServerNodes(network string) ([]models.Node, error) { return nodes, nil } +// GetServerNodes - gets the server nodes of a network +func GetServerNodes(network string) []models.Node { + var nodes, err = GetNetworkNodes(network) + var serverNodes = make([]models.Node, 0) + if err != nil { + return serverNodes + } + for _, node := range nodes { + if node.IsServer == "yes" { + serverNodes = append(serverNodes, node) + } + } + return serverNodes +} + // UncordonNode - approves a node to join a network func UncordonNode(nodeid string) (models.Node, error) { node, err := GetNodeByID(nodeid) diff --git a/logic/peers.go b/logic/peers.go index 311ee2d6..caa0da25 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -20,6 +20,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { if err != nil { return models.PeerUpdate{}, err } + var serverNodeAddresses = []string{} for _, peer := range networkNodes { if peer.ID == node.ID { //skip yourself @@ -55,9 +56,13 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { PersistentKeepaliveInterval: &keepalive, } peers = append(peers, peerData) + if peer.IsServer == "yes" { + serverNodeAddresses = append(serverNodeAddresses, peer.Address) + } } peerUpdate.Network = node.Network peerUpdate.Peers = peers + peerUpdate.ServerAddrs = serverNodeAddresses return peerUpdate, nil } diff --git a/models/mqtt.go b/models/mqtt.go index 3ccab619..10b8d063 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -2,12 +2,15 @@ package models import "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +// PeerUpdate - struct type PeerUpdate struct { - Network string - Peers []wgtypes.PeerConfig + Network string `json:"network" bson:"network"` + ServerAddrs []string `json:"serversaddrs" bson:"serversaddrs"` + Peers []wgtypes.PeerConfig `json:"peers" bson:"peers"` } +// KeyUpdate - key update struct type KeyUpdate struct { - Network string - Interface string + Network string `json:"network" bson:"network"` + Interface string `json:"interface" bson:"interface"` } diff --git a/models/network.go b/models/network.go index c22dca0f..492b61d0 100644 --- a/models/network.go +++ b/models/network.go @@ -34,10 +34,11 @@ type Network struct { LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` // checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL - DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` - DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` - DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` - DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` + DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` + DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` + DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultServerAddrs []string `json:"defaultserveraddrs" bson:"defaultserveraddrs"` } // SaveData - sensitive fields of a network that should be kept the same diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index 1fbc51f7..d7c41e9f 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -3,10 +3,12 @@ package functions import ( "context" "encoding/json" + "fmt" "log" "os" "os/signal" "runtime" + "sync" "syscall" "time" @@ -19,6 +21,22 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +var messageCache = make(map[string]string, 20) + +const lastNodeUpdate = "lnu" +const lastPeerUpdate = "lpu" + +func insert(network, which, cache string) { + var mu sync.Mutex + mu.Lock() + defer mu.Unlock() + messageCache[fmt.Sprintf("%s%s", network, which)] = cache +} + +func read(network, which string) string { + return messageCache[fmt.Sprintf("%s%s", network, which)] +} + // Daemon runs netclient daemon from command line func Daemon() error { ctx, cancel := context.WithCancel(context.Background()) @@ -41,8 +59,12 @@ func Daemon() error { // SetupMQTT creates a connection to broker and return client func SetupMQTT(cfg *config.ClientConfig) mqtt.Client { opts := mqtt.NewClientOptions() - ncutils.Log("setting broker to " + cfg.Server.CoreDNSAddr + ":1883") - opts.AddBroker(cfg.Server.CoreDNSAddr + ":1883") + for i, addr := range cfg.Node.NetworkSettings.DefaultServerAddrs { + if addr != "" { + ncutils.Log(fmt.Sprintf("adding server (%d) to listen on network %s \n", (i + 1), cfg.Node.Network)) + opts.AddBroker(addr + ":1883") + } + } opts.SetDefaultPublishHandler(All) client := mqtt.NewClient(opts) if token := client.Connect(); token.Wait() && token.Error() != nil { @@ -102,6 +124,12 @@ var NodeUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) ncutils.Log("error unmarshalling node update data" + err.Error()) return } + // see if cache hit, if so skip + var currentMessage = read(newNode.Network, lastNodeUpdate) + if currentMessage == string(msg.Payload()) { + return + } + insert(newNode.Network, lastNodeUpdate, string(msg.Payload())) cfg.Network = newNode.Network cfg.ReadConfig() //check if interface name has changed if so delete. @@ -177,10 +205,24 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) ncutils.Log("error unmarshalling peer data") return } + // see if cache hit, if so skip + var currentMessage = read(peerUpdate.Network, lastPeerUpdate) + if currentMessage == string(msg.Payload()) { + return + } + insert(peerUpdate.Network, lastPeerUpdate, string(msg.Payload())) ncutils.Log("update peer handler") var cfg config.ClientConfig cfg.Network = peerUpdate.Network cfg.ReadConfig() + var shouldReSub = shouldResub(cfg.Node.NetworkSettings.DefaultServerAddrs, peerUpdate.ServerAddrs) + if shouldReSub { + client.Disconnect(250) // kill client + // un sub, re sub.. how? + client.Unsubscribe("update/"+cfg.Node.ID, "update/peers/"+cfg.Node.ID) + cfg.Node.NetworkSettings.DefaultServerAddrs = peerUpdate.ServerAddrs + + } file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf" err = wireguard.UpdateWgPeers(file, peerUpdate.Peers) if err != nil { @@ -196,6 +238,26 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) }() } +// Resubscribe --- handles resubscribing if needed +func Resubscribe(client mqtt.Client, cfg *config.ClientConfig) { + if err := config.ModConfig(&cfg.Node); err == nil { + client.Disconnect(250) + client = SetupMQTT(cfg) + if token := client.Subscribe("update/"+cfg.Node.ID, 0, NodeUpdate); token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + if cfg.DebugOn { + ncutils.Log("subscribed to node updates for node " + cfg.Node.Name + " update/" + cfg.Node.ID) + } + if token := client.Subscribe("update/peers/"+cfg.Node.ID, 0, UpdatePeers); token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + ncutils.Log("finished re subbing") + } else { + ncutils.Log("could not mod config when re-subbing") + } +} + // UpdateKeys -- updates private key and returns new publickey func UpdateKeys(cfg *config.ClientConfig, client mqtt.Client) error { ncutils.Log("received message to update keys") @@ -291,3 +353,15 @@ func Hello(cfg *config.ClientConfig, network string) { } client.Disconnect(250) } + +func shouldResub(currentServers, newServers []string) bool { + if len(currentServers) != len(newServers) { + return false + } + for _, srv := range currentServers { + if !ncutils.StringSliceContains(newServers, srv) { + return true + } + } + return false +} diff --git a/netclient/ncutils/netclientutils.go b/netclient/ncutils/netclientutils.go index 0bfff3c3..b1957e51 100644 --- a/netclient/ncutils/netclientutils.go +++ b/netclient/ncutils/netclientutils.go @@ -532,3 +532,13 @@ func CheckWG() { log.Println("running userspace WireGuard with " + uspace) } } + +// StringSliceContains - sees if a string slice contains a string element +func StringSliceContains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} From 424c801c6c9b762412946994c0b9f8b8ff072f33 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Tue, 25 Jan 2022 23:04:03 -0500 Subject: [PATCH 2/2] server update --- controllers/node_grpc.go | 7 +++++-- models/network.go | 10 +++++----- netclient/config/config.go | 1 + netclient/functions/daemon.go | 9 +++++---- netclient/functions/join.go | 2 ++ 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/controllers/node_grpc.go b/controllers/node_grpc.go index b0130b79..53bb8d99 100644 --- a/controllers/node_grpc.go +++ b/controllers/node_grpc.go @@ -68,8 +68,11 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) } var serverNodes = logic.GetServerNodes(node.Network) - for _, server := range serverNodes { - node.NetworkSettings.DefaultServerAddrs = append(node.NetworkSettings.DefaultServerAddrs, server.Address) + for i, server := range serverNodes { + node.NetworkSettings.DefaultServerAddrs += server.Address + if i < len(serverNodes)-1 { + node.NetworkSettings.DefaultServerAddrs += "," + } } err = logic.CreateNode(&node) diff --git a/models/network.go b/models/network.go index 492b61d0..0103199b 100644 --- a/models/network.go +++ b/models/network.go @@ -34,11 +34,11 @@ type Network struct { LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` // checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL - DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` - DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` - DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` - DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` - DefaultServerAddrs []string `json:"defaultserveraddrs" bson:"defaultserveraddrs"` + DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` + DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` + DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` + DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultServerAddrs string `json:"defaultserveraddrs" bson:"defaultserveraddrs" yaml:"defaultserveraddrs"` } // SaveData - sensitive fields of a network that should be kept the same diff --git a/netclient/config/config.go b/netclient/config/config.go index 66d19762..fb602bfa 100644 --- a/netclient/config/config.go +++ b/netclient/config/config.go @@ -119,6 +119,7 @@ func ModConfig(node *models.Node) error { modconfig.Node = (*node) modconfig.NetworkSettings = node.NetworkSettings + log.Printf("%v \n", modconfig) err = Write(&modconfig, network) return err } diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index d7c41e9f..4f01475b 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "runtime" + "strings" "sync" "syscall" "time" @@ -59,7 +60,8 @@ func Daemon() error { // SetupMQTT creates a connection to broker and return client func SetupMQTT(cfg *config.ClientConfig) mqtt.Client { opts := mqtt.NewClientOptions() - for i, addr := range cfg.Node.NetworkSettings.DefaultServerAddrs { + serverAddrs := strings.Split(cfg.Node.NetworkSettings.DefaultServerAddrs, ",") + for i, addr := range serverAddrs { if addr != "" { ncutils.Log(fmt.Sprintf("adding server (%d) to listen on network %s \n", (i + 1), cfg.Node.Network)) opts.AddBroker(addr + ":1883") @@ -215,13 +217,12 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) var cfg config.ClientConfig cfg.Network = peerUpdate.Network cfg.ReadConfig() - var shouldReSub = shouldResub(cfg.Node.NetworkSettings.DefaultServerAddrs, peerUpdate.ServerAddrs) + var shouldReSub = shouldResub(strings.Split(cfg.Node.NetworkSettings.DefaultServerAddrs, ","), peerUpdate.ServerAddrs) if shouldReSub { client.Disconnect(250) // kill client // un sub, re sub.. how? client.Unsubscribe("update/"+cfg.Node.ID, "update/peers/"+cfg.Node.ID) - cfg.Node.NetworkSettings.DefaultServerAddrs = peerUpdate.ServerAddrs - + cfg.Node.NetworkSettings.DefaultServerAddrs = strings.Join(peerUpdate.ServerAddrs, ",") } file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf" err = wireguard.UpdateWgPeers(file, peerUpdate.Peers) diff --git a/netclient/functions/join.go b/netclient/functions/join.go index f5c38c85..b37bf24c 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -161,6 +161,7 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { if err = json.Unmarshal([]byte(nodeData), &node); err != nil { return err } + log.Printf("%v \n", nodeData) } // get free port based on returned default listen port @@ -183,6 +184,7 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { } if node.IsServer != "yes" { // == handle client side == + cfg.Node = node err = config.ModConfig(&node) if err != nil { return err