diff --git a/controllers/node_grpc.go b/controllers/node_grpc.go index c2fe6f19..8edebf5f 100644 --- a/controllers/node_grpc.go +++ b/controllers/node_grpc.go @@ -67,6 +67,16 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object) } } + var serverNodes = logic.GetServerNodes(node.Network) + var serverAddrs = make([]models.ServerAddr, len(serverNodes)) + for i, server := range serverNodes { + serverAddrs[i] = models.ServerAddr{ + IsLeader: logic.IsLeader(&server), + Address: server.Address, + } + } + node.NetworkSettings.DefaultServerAddrs = serverAddrs + err = logic.CreateNode(&node) if err != nil { return nil, err diff --git a/logger/logger.go b/logger/logger.go index 6ffd6ad0..b01ffc65 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -6,6 +6,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" ) @@ -36,6 +37,9 @@ func ResetLogs() { // Log - handles adding logs func Log(verbosity int, message ...string) { + var mu sync.Mutex + mu.Lock() + defer mu.Unlock() var currentTime = time.Now() var currentMessage = makeString(message...) if int32(verbosity) <= getVerbose() && getVerbose() >= 0 { diff --git a/logic/nodes.go b/logic/nodes.go index 48008db3..f6f74c4d 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -58,6 +58,21 @@ func GetSortedNetworkServerNodes(network string) ([]models.Node, error) { return nodes, nil } +// GetServerNodes - gets the server nodes of a network +func GetServerNodes(network string) []models.Node { + var nodes, err = GetNetworkNodes(network) + var serverNodes = make([]models.Node, 0) + if err != nil { + return serverNodes + } + for _, node := range nodes { + if node.IsServer == "yes" { + serverNodes = append(serverNodes, node) + } + } + return serverNodes +} + // UncordonNode - approves a node to join a network func UncordonNode(nodeid string) (models.Node, error) { node, err := GetNodeByID(nodeid) diff --git a/logic/peers.go b/logic/peers.go index 311ee2d6..44e4ae5c 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -20,6 +20,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { if err != nil { return models.PeerUpdate{}, err } + var serverNodeAddresses = []models.ServerAddr{} for _, peer := range networkNodes { if peer.ID == node.ID { //skip yourself @@ -55,9 +56,13 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { PersistentKeepaliveInterval: &keepalive, } peers = append(peers, peerData) + if peer.IsServer == "yes" { + serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address}) + } } peerUpdate.Network = node.Network peerUpdate.Peers = peers + peerUpdate.ServerAddrs = serverNodeAddresses return peerUpdate, nil } diff --git a/models/mqtt.go b/models/mqtt.go index 3ccab619..e9bead41 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -2,12 +2,15 @@ package models import "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +// PeerUpdate - struct type PeerUpdate struct { - Network string - Peers []wgtypes.PeerConfig + Network string `json:"network" bson:"network" yaml:"network"` + ServerAddrs []ServerAddr `json:"serveraddrs" bson:"serveraddrs" yaml:"serveraddrs"` + Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` } +// KeyUpdate - key update struct type KeyUpdate struct { - Network string - Interface string + Network string `json:"network" bson:"network"` + Interface string `json:"interface" bson:"interface"` } diff --git a/models/network.go b/models/network.go index c22dca0f..690faeb8 100644 --- a/models/network.go +++ b/models/network.go @@ -34,10 +34,11 @@ type Network struct { LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` // checkin interval is depreciated at the network level. Set on server with CHECKIN_INTERVAL - DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` - DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` - DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` - DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultCheckInInterval int32 `json:"checkininterval,omitempty" bson:"checkininterval,omitempty" validate:"omitempty,numeric,min=2,max=100000"` + DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` + DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` + DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultServerAddrs []ServerAddr `json:"defaultserveraddrs" bson:"defaultserveraddrs" yaml:"defaultserveraddrs"` } // SaveData - sensitive fields of a network that should be kept the same diff --git a/models/structs.go b/models/structs.go index 3e772bc3..3020c1bd 100644 --- a/models/structs.go +++ b/models/structs.go @@ -169,3 +169,9 @@ type Telemetry struct { UUID string `json:"uuid" bson:"uuid"` LastSend int64 `json:"lastsend" bson:"lastsend"` } + +// ServerAddr - to pass to clients to tell server addresses and if it's the leader or not +type ServerAddr struct { + IsLeader bool `json:"isleader" bson:"isleader" yaml:"isleader"` + Address string `json:"address" bson:"address" yaml:"address"` +} diff --git a/netclient/functions/daemon.go b/netclient/functions/daemon.go index 1fbc51f7..a8d31bb6 100644 --- a/netclient/functions/daemon.go +++ b/netclient/functions/daemon.go @@ -3,10 +3,12 @@ package functions import ( "context" "encoding/json" + "fmt" "log" "os" "os/signal" "runtime" + "sync" "syscall" "time" @@ -19,6 +21,22 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +var messageCache = make(map[string]string, 20) + +const lastNodeUpdate = "lnu" +const lastPeerUpdate = "lpu" + +func insert(network, which, cache string) { + var mu sync.Mutex + mu.Lock() + defer mu.Unlock() + messageCache[fmt.Sprintf("%s%s", network, which)] = cache +} + +func read(network, which string) string { + return messageCache[fmt.Sprintf("%s%s", network, which)] +} + // Daemon runs netclient daemon from command line func Daemon() error { ctx, cancel := context.WithCancel(context.Background()) @@ -41,8 +59,13 @@ func Daemon() error { // SetupMQTT creates a connection to broker and return client func SetupMQTT(cfg *config.ClientConfig) mqtt.Client { opts := mqtt.NewClientOptions() - ncutils.Log("setting broker to " + cfg.Server.CoreDNSAddr + ":1883") - opts.AddBroker(cfg.Server.CoreDNSAddr + ":1883") + for _, server := range cfg.Node.NetworkSettings.DefaultServerAddrs { + if server.Address != "" && server.IsLeader { + ncutils.Log(fmt.Sprintf("adding server (%s) to listen on network %s \n", server.Address, cfg.Node.Network)) + opts.AddBroker(server.Address + ":1883") + break + } + } opts.SetDefaultPublishHandler(All) client := mqtt.NewClient(opts) if token := client.Connect(); token.Wait() && token.Error() != nil { @@ -65,13 +88,13 @@ func MessageQueue(ctx context.Context, network string) { } ncutils.Log("subscribed to all topics for debugging purposes") } - if token := client.Subscribe("update/"+cfg.Node.ID, 0, NodeUpdate); token.Wait() && token.Error() != nil { + if token := client.Subscribe("update/"+cfg.Node.ID, 0, mqtt.MessageHandler(NodeUpdate)); token.Wait() && token.Error() != nil { log.Fatal(token.Error()) } if cfg.DebugOn { ncutils.Log("subscribed to node updates for node " + cfg.Node.Name + " update/" + cfg.Node.ID) } - if token := client.Subscribe("update/peers/"+cfg.Node.ID, 0, UpdatePeers); token.Wait() && token.Error() != nil { + if token := client.Subscribe("update/peers/"+cfg.Node.ID, 0, mqtt.MessageHandler(UpdatePeers)); token.Wait() && token.Error() != nil { log.Fatal(token.Error()) } if cfg.DebugOn { @@ -91,8 +114,7 @@ var All mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { } // NodeUpdate -- mqtt message handler for /update/ topic -var NodeUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { - ncutils.Log("received message to update node " + string(msg.Payload())) +func NodeUpdate(client mqtt.Client, msg mqtt.Message) { //potentiall blocking i/o so do this in a go routine go func() { var newNode models.Node @@ -102,6 +124,13 @@ var NodeUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) ncutils.Log("error unmarshalling node update data" + err.Error()) return } + ncutils.Log("received message to update node " + newNode.Name) + // see if cache hit, if so skip + var currentMessage = read(newNode.Network, lastNodeUpdate) + if currentMessage == string(msg.Payload()) { + return + } + insert(newNode.Network, lastNodeUpdate, string(msg.Payload())) cfg.Network = newNode.Network cfg.ReadConfig() //check if interface name has changed if so delete. @@ -169,7 +198,7 @@ var NodeUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) } // UpdatePeers -- mqtt message handler for /update/peers/ topic -var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) { +func UpdatePeers(client mqtt.Client, msg mqtt.Message) { go func() { var peerUpdate models.PeerUpdate err := json.Unmarshal(msg.Payload(), &peerUpdate) @@ -177,10 +206,21 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) ncutils.Log("error unmarshalling peer data") return } + // see if cache hit, if so skip + var currentMessage = read(peerUpdate.Network, lastPeerUpdate) + if currentMessage == string(msg.Payload()) { + return + } + insert(peerUpdate.Network, lastPeerUpdate, string(msg.Payload())) ncutils.Log("update peer handler") var cfg config.ClientConfig cfg.Network = peerUpdate.Network cfg.ReadConfig() + var shouldReSub = shouldResub(cfg.Node.NetworkSettings.DefaultServerAddrs, peerUpdate.ServerAddrs) + if shouldReSub { + Resubscribe(client, &cfg) + cfg.Node.NetworkSettings.DefaultServerAddrs = peerUpdate.ServerAddrs + } file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf" err = wireguard.UpdateWgPeers(file, peerUpdate.Peers) if err != nil { @@ -196,6 +236,29 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) }() } +// Resubscribe --- handles resubscribing if needed +func Resubscribe(client mqtt.Client, cfg *config.ClientConfig) error { + if err := config.ModConfig(&cfg.Node); err == nil { + ncutils.Log("resubbing on network " + cfg.Node.Network) + client.Disconnect(250) + client = SetupMQTT(cfg) + if token := client.Subscribe("update/"+cfg.Node.ID, 0, NodeUpdate); token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + if cfg.DebugOn { + ncutils.Log("subscribed to node updates for node " + cfg.Node.Name + " update/" + cfg.Node.ID) + } + if token := client.Subscribe("update/peers/"+cfg.Node.ID, 0, UpdatePeers); token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + ncutils.Log("finished re subbing") + return nil + } else { + ncutils.Log("could not mod config when re-subbing") + return err + } +} + // UpdateKeys -- updates private key and returns new publickey func UpdateKeys(cfg *config.ClientConfig, client mqtt.Client) error { ncutils.Log("received message to update keys") @@ -291,3 +354,15 @@ func Hello(cfg *config.ClientConfig, network string) { } client.Disconnect(250) } + +func shouldResub(currentServers, newServers []models.ServerAddr) bool { + if len(currentServers) != len(newServers) { + return true + } + for _, srv := range currentServers { + if !ncutils.ServerAddrSliceContains(newServers, srv) { + return true + } + } + return false +} diff --git a/netclient/functions/join.go b/netclient/functions/join.go index f5c38c85..c8451f42 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os/exec" + "runtime" "github.com/google/uuid" nodepb "github.com/gravitl/netmaker/grpc" @@ -101,8 +102,7 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { // make sure name is appropriate, if not, give blank name cfg.Node.Name = formatName(cfg.Node) // differentiate between client/server here - var node models.Node // fill this node with appropriate calls - postnode := &models.Node{ + var node = models.Node{ Password: cfg.Node.Password, ID: cfg.Node.ID, MacAddress: cfg.Node.MacAddress, @@ -124,44 +124,17 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { UDPHolePunch: cfg.Node.UDPHolePunch, } - if cfg.Node.IsServer != "yes" { - ncutils.Log("joining " + cfg.Network + " at " + cfg.Server.GRPCAddress) - var wcclient nodepb.NodeServiceClient + ncutils.Log("joining " + cfg.Network + " at " + cfg.Server.GRPCAddress) + var wcclient nodepb.NodeServiceClient - conn, err := grpc.Dial(cfg.Server.GRPCAddress, - ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) + conn, err := grpc.Dial(cfg.Server.GRPCAddress, + ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) - if err != nil { - log.Fatalf("Unable to establish client connection to "+cfg.Server.GRPCAddress+": %v", err) - } - defer conn.Close() - wcclient = nodepb.NewNodeServiceClient(conn) - - if err = config.ModConfig(postnode); err != nil { - return err - } - data, err := json.Marshal(postnode) - if err != nil { - return err - } - // Create node on server - res, err := wcclient.CreateNode( - context.TODO(), - &nodepb.Object{ - Data: string(data), - Type: nodepb.NODE_TYPE, - }, - ) - if err != nil { - return err - } - ncutils.PrintLog("node created on remote server...updating configs", 1) - - nodeData := res.Data - if err = json.Unmarshal([]byte(nodeData), &node); err != nil { - return err - } + if err != nil { + log.Fatalf("Unable to establish client connection to "+cfg.Server.GRPCAddress+": %v", err) } + defer conn.Close() + wcclient = nodepb.NewNodeServiceClient(conn) // get free port based on returned default listen port node.ListenPort, err = ncutils.GetFreePort(node.ListenPort) @@ -182,32 +155,48 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { cfg.Node.IsStatic = "yes" } - if node.IsServer != "yes" { // == handle client side == - err = config.ModConfig(&node) - if err != nil { - return err - } - err = wireguard.StorePrivKey(privateKey, cfg.Network) - if err != nil { - return err - } - if node.IsPending == "yes" { - ncutils.Log("Node is marked as PENDING.") - ncutils.Log("Awaiting approval from Admin before configuring WireGuard.") - if cfg.Daemon != "off" { - return daemon.InstallDaemon(cfg) - } - } - // pushing any local changes to server before starting wireguard - err = Push(cfg.Network) - if err != nil { - return err - } - // attempt to make backup - if err = config.SaveBackup(node.Network); err != nil { - ncutils.Log("failed to make backup, node will not auto restore if config is corrupted") + err = wireguard.StorePrivKey(privateKey, cfg.Network) + if err != nil { + return err + } + if node.IsPending == "yes" { + ncutils.Log("Node is marked as PENDING.") + ncutils.Log("Awaiting approval from Admin before configuring WireGuard.") + if cfg.Daemon != "off" { + return daemon.InstallDaemon(cfg) } } + data, err := json.Marshal(&node) + if err != nil { + return err + } + // Create node on server + res, err := wcclient.CreateNode( + context.TODO(), + &nodepb.Object{ + Data: string(data), + Type: nodepb.NODE_TYPE, + }, + ) + if err != nil { + return err + } + ncutils.PrintLog("node created on remote server...updating configs", 1) + + nodeData := res.Data + if err = json.Unmarshal([]byte(nodeData), &node); err != nil { + return err + } + node.OS = runtime.GOOS + cfg.Node = node + err = config.ModConfig(&node) + if err != nil { + return err + } + // attempt to make backup + if err = config.SaveBackup(node.Network); err != nil { + ncutils.Log("failed to make backup, node will not auto restore if config is corrupted") + } ncutils.Log("retrieving peers") peers, hasGateway, gateways, err := server.GetPeers(node.MacAddress, cfg.Network, cfg.Server.GRPCAddress, node.IsDualStack == "yes", node.IsIngressGateway == "yes", node.IsServer == "yes") diff --git a/netclient/ncutils/netclientutils.go b/netclient/ncutils/netclientutils.go index 0bfff3c3..0f158229 100644 --- a/netclient/ncutils/netclientutils.go +++ b/netclient/ncutils/netclientutils.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "github.com/gravitl/netmaker/models" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -532,3 +533,13 @@ func CheckWG() { log.Println("running userspace WireGuard with " + uspace) } } + +// ServerAddrSliceContains - sees if a string slice contains a string element +func ServerAddrSliceContains(slice []models.ServerAddr, item models.ServerAddr) bool { + for _, s := range slice { + if s.Address == item.Address && s.IsLeader == item.IsLeader { + return true + } + } + return false +}