diff --git a/database/database.go b/database/database.go index 497bb94a..2c4d50d6 100644 --- a/database/database.go +++ b/database/database.go @@ -3,33 +3,74 @@ package database import ( "encoding/json" "errors" - "github.com/gravitl/netmaker/servercfg" "log" "time" + + "github.com/gravitl/netmaker/servercfg" ) +// NETWORKS_TABLE_NAME - networks table const NETWORKS_TABLE_NAME = "networks" + +// NODES_TABLE_NAME - nodes table const NODES_TABLE_NAME = "nodes" + +// DELETED_NODES_TABLE_NAME - deleted nodes table const DELETED_NODES_TABLE_NAME = "deletednodes" + +// USERS_TABLE_NAME - users table const USERS_TABLE_NAME = "users" + +// DNS_TABLE_NAME - dns table const DNS_TABLE_NAME = "dns" + +// EXT_CLIENT_TABLE_NAME - ext client table const EXT_CLIENT_TABLE_NAME = "extclients" + +// INT_CLIENTS_TABLE_NAME - int client table const INT_CLIENTS_TABLE_NAME = "intclients" + +// PEERS_TABLE_NAME - peers table const PEERS_TABLE_NAME = "peers" + +// SERVERCONF_TABLE_NAME +const SERVERCONF_TABLE_NAME = "serverconf" + +// DATABASE_FILENAME - database file name const DATABASE_FILENAME = "netmaker.db" // == ERROR CONSTS == + +// NO_RECORD - no singular result found const NO_RECORD = "no result found" + +// NO_RECORDS - no results found const NO_RECORDS = "could not find any records" // == Constants == + +// INIT_DB - initialize db const INIT_DB = "init" + +// CREATE_TABLE - create table const const CREATE_TABLE = "createtable" + +// INSERT - insert into db const const INSERT = "insert" + +// INSERT_PEER - insert peer into db const const INSERT_PEER = "insertpeer" + +// DELETE - delete db record const const DELETE = "delete" + +// DELETE_ALL - delete a table const const DELETE_ALL = "deleteall" + +// FETCH_ALL - fetch table contents const const FETCH_ALL = "fetchall" + +// CLOSE_DB - graceful close of db const const CLOSE_DB = "closedb" func getCurrentDB() map[string]interface{} { @@ -72,17 +113,20 @@ func createTables() { createTable(EXT_CLIENT_TABLE_NAME) createTable(INT_CLIENTS_TABLE_NAME) createTable(PEERS_TABLE_NAME) + createTable(SERVERCONF_TABLE_NAME) } func createTable(tableName string) error { return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName) } +// IsJSONString - checks if valid json func IsJSONString(value string) bool { var jsonInt interface{} return json.Unmarshal([]byte(value), &jsonInt) == nil } +// Insert - inserts object into db func Insert(key string, value string, tableName string) error { if key != "" && value != "" && IsJSONString(value) { return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName) @@ -91,6 +135,7 @@ func Insert(key string, value string, tableName string) error { } } +// InsertPeer - inserts peer into db func InsertPeer(key string, value string) error { if key != "" && value != "" && IsJSONString(value) { return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value) @@ -99,10 +144,12 @@ func InsertPeer(key string, value string) error { } } +// DeleteRecord - deletes a record from db func DeleteRecord(tableName string, key string) error { return getCurrentDB()[DELETE].(func(string, string) error)(tableName, key) } +// DeleteAllRecords - removes a table and remakes func DeleteAllRecords(tableName string) error { err := getCurrentDB()[DELETE_ALL].(func(string) error)(tableName) if err != nil { @@ -115,6 +162,7 @@ func DeleteAllRecords(tableName string) error { return nil } +// FetchRecord - fetches a record func FetchRecord(tableName string, key string) (string, error) { results, err := FetchRecords(tableName) if err != nil { @@ -126,10 +174,12 @@ func FetchRecord(tableName string, key string) (string, error) { return results[key], nil } +// FetchRecords - fetches all records in given table func FetchRecords(tableName string) (map[string]string, error) { return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName) } +// CloseDB - closes a database gracefully func CloseDB() { getCurrentDB()[CLOSE_DB].(func())() } diff --git a/database/postgres.go b/database/postgres.go index f6847ea9..74b30e1a 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -1,15 +1,17 @@ package database import ( - "github.com/gravitl/netmaker/servercfg" "database/sql" "errors" - _ "github.com/lib/pq" "fmt" + "github.com/gravitl/netmaker/servercfg" + _ "github.com/lib/pq" ) +// PGDB - database object for PostGreSQL var PGDB *sql.DB +// PG_FUNCTIONS - map of db functions for PostGreSQL var PG_FUNCTIONS = map[string]interface{}{ INIT_DB: initPGDB, CREATE_TABLE: pgCreateTable, @@ -21,14 +23,13 @@ var PG_FUNCTIONS = map[string]interface{}{ CLOSE_DB: pgCloseDB, } -func getPGConnString() string{ +func getPGConnString() string { pgconf := servercfg.GetSQLConf() pgConn := fmt.Sprintf("host=%s port=%d user=%s "+ - "password=%s dbname=%s sslmode=%s", - pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode) + "password=%s dbname=%s sslmode=%s", + pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode) return pgConn } - func initPGDB() error { connString := getPGConnString() diff --git a/logic/network.go b/logic/network.go new file mode 100644 index 00000000..bc4583de --- /dev/null +++ b/logic/network.go @@ -0,0 +1,63 @@ +package logic + +import ( + "net" + + "github.com/gravitl/netmaker/models" +) + +// GetLocalIP - gets the local ip +func GetLocalIP(node models.Node) string { + + var local string + + ifaces, err := net.Interfaces() + if err != nil { + return local + } + _, localrange, err := net.ParseCIDR(node.LocalRange) + if err != nil { + return local + } + + found := false + for _, i := range ifaces { + if i.Flags&net.FlagUp == 0 { + continue // interface down + } + if i.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := i.Addrs() + if err != nil { + return local + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + if !found { + ip = v.IP + local = ip.String() + if node.IsLocal == "yes" { + found = localrange.Contains(ip) + } else { + found = true + } + } + case *net.IPAddr: + if !found { + ip = v.IP + local = ip.String() + if node.IsLocal == "yes" { + found = localrange.Contains(ip) + + } else { + found = true + } + } + } + } + } + return local +} diff --git a/logic/server.go b/logic/server.go new file mode 100644 index 00000000..9163bd3b --- /dev/null +++ b/logic/server.go @@ -0,0 +1,353 @@ +package logic + +import ( + "errors" + "log" + "net" + "runtime" + "strconv" + "strings" + "time" + + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/config" + "github.com/gravitl/netmaker/netclient/ncutils" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// == Join, Checkin, and Leave for Server == +func ServerJoin(cfg config.ClientConfig, privateKey string) error { + var err error + + if cfg.Network == "" { + return errors.New("no network provided") + } + + if cfg.Node.LocalRange != "" && cfg.Node.LocalAddress == "" { + Log("local vpn, getting local address from range: "+cfg.Node.LocalRange, 1) + cfg.Node.LocalAddress = GetLocalIP(cfg.Node) + } + + if cfg.Node.Endpoint == "" { + if cfg.Node.IsLocal == "yes" && cfg.Node.LocalAddress != "" { + cfg.Node.Endpoint = cfg.Node.LocalAddress + } else { + cfg.Node.Endpoint, err = ncutils.GetPublicIP() + } + if err != nil || cfg.Node.Endpoint == "" { + ncutils.Log("Error setting cfg.Node.Endpoint.") + return err + } + } + + // Generate and set public/private WireGuard Keys + if privateKey == "" { + wgPrivatekey, err := wgtypes.GeneratePrivateKey() + if err != nil { + Log(err.Error(), 1) + return err + } + privateKey = wgPrivatekey.String() + cfg.Node.PublicKey = wgPrivatekey.PublicKey().String() + } + + if cfg.Node.MacAddress == "" { + macs, err := ncutils.GetMacAddr() + if err != nil { + return err + } else if len(macs) == 0 { + Log("could not retrieve mac address for server", 1) + return errors.New("failed to get server mac") + } else { + cfg.Node.MacAddress = macs[0] + } + } + + var node models.Node // fill this node with appropriate calls + var postnode *models.Node + postnode = &models.Node{ + Password: cfg.Node.Password, + MacAddress: cfg.Node.MacAddress, + AccessKey: cfg.Server.AccessKey, + Network: cfg.Network, + ListenPort: cfg.Node.ListenPort, + PostUp: cfg.Node.PostUp, + PostDown: cfg.Node.PostDown, + PersistentKeepalive: cfg.Node.PersistentKeepalive, + LocalAddress: cfg.Node.LocalAddress, + Interface: cfg.Node.Interface, + PublicKey: cfg.Node.PublicKey, + DNSOn: cfg.Node.DNSOn, + Name: cfg.Node.Name, + Endpoint: cfg.Node.Endpoint, + SaveConfig: cfg.Node.SaveConfig, + UDPHolePunch: cfg.Node.UDPHolePunch, + } + + Log("adding a server instance on network "+postnode.Network, 2) + node, err = CreateNode(*postnode, cfg.Network) + if err != nil { + return err + } + err = SetNetworkNodesLastModified(node.Network) + if err != nil { + return err + } + + // get free port based on returned default listen port + node.ListenPort, err = ncutils.GetFreePort(node.ListenPort) + if err != nil { + Log("Error retrieving port: "+err.Error(), 2) + } + + // safety check. If returned node from server is local, but not currently configured as local, set to local addr + if cfg.Node.IsLocal != "yes" && node.IsLocal == "yes" && node.LocalRange != "" { + node.LocalAddress, err = ncutils.GetLocalIP(node.LocalRange) + if err != nil { + return err + } + node.Endpoint = node.LocalAddress + } + + node.SetID() + if err = StorePrivKey(node.ID, privateKey); err != nil { + return err + } + if err = ServerPush(node.MacAddress, node.Network); err != nil { + return err + } + + peers, hasGateway, gateways, err := GetServerPeers(node.MacAddress, cfg.Network, cfg.Server.GRPCAddress, node.IsDualStack == "yes", node.IsIngressGateway == "yes", node.IsServer == "yes") + if err != nil && !ncutils.IsEmptyRecord(err) { + ncutils.Log("failed to retrieve peers") + return err + } + + err = initWireguard(&node, privateKey, peers, hasGateway, gateways) + if err != nil { + return err + } + + return nil +} + +// ServerPush - pushes config changes for server checkins/join +func ServerPush(mac string, network string) error { + + var serverNode models.Node + var err error + serverNode, err = GetNode(mac, network) + if err != nil && !ncutils.IsEmptyRecord(err) { + return err + } + serverNode.OS = runtime.GOOS + serverNode.SetLastCheckIn() + err = serverNode.Update(&serverNode) + return err +} + +func GetServerPeers(macaddress string, network string, server string, dualstack bool, isIngressGateway bool, isServer bool) ([]wgtypes.PeerConfig, bool, []string, error) { + hasGateway := false + var err error + var gateways []string + var peers []wgtypes.PeerConfig + var nodecfg models.Node + var nodes []models.Node // fill above fields from server or client + + nodecfg, err = GetNode(macaddress, network) + if err != nil { + return nil, hasGateway, gateways, err + } + nodes, err = GetPeers(nodecfg) + if err != nil { + return nil, hasGateway, gateways, err + } + + keepalive := nodecfg.PersistentKeepalive + keepalivedur, err := time.ParseDuration(strconv.FormatInt(int64(keepalive), 10) + "s") + keepaliveserver, err := time.ParseDuration(strconv.FormatInt(int64(5), 10) + "s") + if err != nil { + Log("Issue with format of keepalive value. Please update netconfig: "+err.Error(), 1) + return nil, hasGateway, gateways, err + } + + for _, node := range nodes { + pubkey, err := wgtypes.ParseKey(node.PublicKey) + if err != nil { + Log("error parsing key "+pubkey.String(), 1) + return peers, hasGateway, gateways, err + } + + if nodecfg.PublicKey == node.PublicKey { + continue + } + if nodecfg.Endpoint == node.Endpoint { + if nodecfg.LocalAddress != node.LocalAddress && node.LocalAddress != "" { + node.Endpoint = node.LocalAddress + } else { + continue + } + } + + var peer wgtypes.PeerConfig + var peeraddr = net.IPNet{ + IP: net.ParseIP(node.Address), + Mask: net.CIDRMask(32, 32), + } + var allowedips []net.IPNet + allowedips = append(allowedips, peeraddr) + // handle manually set peers + for _, allowedIp := range node.AllowedIPs { + if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil { + nodeEndpointArr := strings.Split(node.Endpoint, ":") + if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists.. + allowedips = append(allowedips, *ipnet) + } + } else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address { + ipnet := net.IPNet{ + IP: net.ParseIP(allowedIp), + Mask: net.CIDRMask(32, 32), + } + allowedips = append(allowedips, ipnet) + } + } + // handle egress gateway peers + if node.IsEgressGateway == "yes" { + hasGateway = true + ranges := node.EgressGatewayRanges + for _, iprange := range ranges { // go through each cidr for egress gateway + _, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr + if err != nil { + ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1) + continue // if can't parse CIDR + } + nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node + if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node + ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2) + continue // skip adding egress range if overlaps with node's ip + } + if ipnet.Contains(net.ParseIP(nodecfg.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node + ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+nodecfg.LocalAddress+", omitting", 2) + continue // skip adding egress range if overlaps with node's local ip + } + gateways = append(gateways, iprange) + if err != nil { + Log("ERROR ENCOUNTERED SETTING GATEWAY", 1) + } else { + allowedips = append(allowedips, *ipnet) + } + } + } + if node.Address6 != "" && dualstack { + var addr6 = net.IPNet{ + IP: net.ParseIP(node.Address6), + Mask: net.CIDRMask(128, 128), + } + allowedips = append(allowedips, addr6) + } + if nodecfg.IsServer == "yes" && !(node.IsServer == "yes") { + peer = wgtypes.PeerConfig{ + PublicKey: pubkey, + PersistentKeepaliveInterval: &keepaliveserver, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } else if keepalive != 0 { + peer = wgtypes.PeerConfig{ + PublicKey: pubkey, + PersistentKeepaliveInterval: &keepalivedur, + Endpoint: &net.UDPAddr{ + IP: net.ParseIP(node.Endpoint), + Port: int(node.ListenPort), + }, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } else { + peer = wgtypes.PeerConfig{ + PublicKey: pubkey, + Endpoint: &net.UDPAddr{ + IP: net.ParseIP(node.Endpoint), + Port: int(node.ListenPort), + }, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + } + peers = append(peers, peer) + } + if isIngressGateway { + extPeers, err := GetServerExtPeers(macaddress, network, server, dualstack) + if err == nil { + peers = append(peers, extPeers...) + } else { + Log("ERROR RETRIEVING EXTERNAL PEERS ON SERVER", 1) + } + } + return peers, hasGateway, gateways, err +} + +// GetServerExtPeers - gets the extpeers for a client +func GetServerExtPeers(macaddress string, network string, server string, dualstack bool) ([]wgtypes.PeerConfig, error) { + var peers []wgtypes.PeerConfig + var nodecfg models.Node + var extPeers []models.Node + var err error + // fill above fields from either client or server + + // fill extPeers with server side logic + nodecfg, err = GetNode(macaddress, network) + if err != nil { + return nil, err + } + var tempPeers []models.ExtPeersResponse + tempPeers, err = GetExtPeersList(nodecfg.MacAddress, nodecfg.Network) + if err != nil { + return nil, err + } + for i := 0; i < len(tempPeers); i++ { + extPeers = append(extPeers, models.Node{ + Address: tempPeers[i].Address, + Address6: tempPeers[i].Address6, + Endpoint: tempPeers[i].Endpoint, + PublicKey: tempPeers[i].PublicKey, + PersistentKeepalive: tempPeers[i].KeepAlive, + ListenPort: tempPeers[i].ListenPort, + LocalAddress: tempPeers[i].LocalAddress, + }) + } + for _, extPeer := range extPeers { + pubkey, err := wgtypes.ParseKey(extPeer.PublicKey) + if err != nil { + log.Println("error parsing key") + return peers, err + } + + if nodecfg.PublicKey == extPeer.PublicKey { + continue + } + + var peer wgtypes.PeerConfig + var peeraddr = net.IPNet{ + IP: net.ParseIP(extPeer.Address), + Mask: net.CIDRMask(32, 32), + } + var allowedips []net.IPNet + allowedips = append(allowedips, peeraddr) + + if extPeer.Address6 != "" && dualstack { + var addr6 = net.IPNet{ + IP: net.ParseIP(extPeer.Address6), + Mask: net.CIDRMask(128, 128), + } + allowedips = append(allowedips, addr6) + } + peer = wgtypes.PeerConfig{ + PublicKey: pubkey, + ReplaceAllowedIPs: true, + AllowedIPs: allowedips, + } + peers = append(peers, peer) + } + return peers, err +} diff --git a/logic/serverconf.go b/logic/serverconf.go new file mode 100644 index 00000000..28378aea --- /dev/null +++ b/logic/serverconf.go @@ -0,0 +1,16 @@ +package logic + +import "github.com/gravitl/netmaker/database" + +// StorePrivKey - stores server client WireGuard privatekey if needed +func StorePrivKey(serverID string, privateKey string) error { + return database.Insert(serverID, privateKey, database.SERVERCONF_TABLE_NAME) +} + +func FetchPrivKey(serverID string) (string, error) { + return database.FetchRecord(database.SERVERCONF_TABLE_NAME, serverID) +} + +func RemovePrivKey(serverID string) error { + return database.DeleteRecord(database.SERVERCONF_TABLE_NAME, serverID) +} diff --git a/logic/util.go b/logic/util.go index 9486c4e6..42a59085 100644 --- a/logic/util.go +++ b/logic/util.go @@ -4,9 +4,11 @@ package logic import ( "encoding/base64" "encoding/json" + "log" "strconv" "strings" "time" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/dnslogic" "github.com/gravitl/netmaker/functions" @@ -290,3 +292,10 @@ func setPeerInfo(node models.Node) models.Node { peer.IsPending = node.IsPending return peer } + +func Log(message string, loglevel int) { + log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) + if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 { + log.Println(message) + } +} diff --git a/logic/wireguard.go b/logic/wireguard.go index 8ff5ce12..2be69616 100644 --- a/logic/wireguard.go +++ b/logic/wireguard.go @@ -1,8 +1,21 @@ package logic import ( + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "os/exec" + "strconv" + "strings" + "time" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/local" + "github.com/gravitl/netmaker/netclient/ncutils" "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // GetSystemPeers - gets the server peers @@ -24,3 +37,168 @@ func GetSystemPeers(node *models.Node) (map[string]string, error) { } return peers, nil } + +func initWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig, hasGateway bool, gateways []string) error { + + key, err := wgtypes.ParseKey(privkey) + if err != nil { + return err + } + + wgclient, err := wgctrl.New() + if err != nil { + return err + } + defer wgclient.Close() + + var ifacename string + if node.Interface != "" { + ifacename = node.Interface + } else { + Log("no interface to configure", 0) + } + if node.Address == "" { + Log("no address to configure", 0) + } + + if ncutils.IsKernel() { + setKernelDevice(ifacename, node.Address) + } + + nodeport := int(node.ListenPort) + var conf wgtypes.Config + conf = wgtypes.Config{ + PrivateKey: &key, + ListenPort: &nodeport, + ReplacePeers: true, + Peers: peers, + } + + if !ncutils.IsKernel() { + var newConf string + if node.UDPHolePunch != "yes" { + newConf, _ = ncutils.CreateUserSpaceConf(node.Address, key.String(), strconv.FormatInt(int64(node.ListenPort), 10), node.MTU, node.PersistentKeepalive, peers) + } else { + newConf, _ = ncutils.CreateUserSpaceConf(node.Address, key.String(), "", node.MTU, node.PersistentKeepalive, peers) + } + confPath := ncutils.GetNetclientPathSpecific() + ifacename + ".conf" + ncutils.PrintLog("writing wg conf file to: "+confPath, 1) + err = ioutil.WriteFile(confPath, []byte(newConf), 0644) + if err != nil { + ncutils.PrintLog("error writing wg conf file to "+confPath+": "+err.Error(), 1) + return err + } + // spin up userspace + apply the conf file + var deviceiface string + if ncutils.IsMac() { + deviceiface, err = local.GetMacIface(node.Address) + if err != nil || deviceiface == "" { + deviceiface = ifacename + } + } + d, _ := wgclient.Device(deviceiface) + for d != nil && d.Name == deviceiface { + _ = RemoveConf(ifacename, false) // remove interface first + time.Sleep(time.Second >> 2) + d, _ = wgclient.Device(deviceiface) + } + err = applyWGQuickConf(confPath) + if err != nil { + ncutils.PrintLog("failed to create wireguard interface", 1) + return err + } + } else { + ipExec, err := exec.LookPath("ip") + if err != nil { + return err + } + + _, err = wgclient.Device(ifacename) + if err != nil { + if os.IsNotExist(err) { + fmt.Println("Device does not exist: ") + fmt.Println(err) + } else { + return errors.New("Unknown config error: " + err.Error()) + } + } + + err = wgclient.ConfigureDevice(ifacename, conf) + if err != nil { + if os.IsNotExist(err) { + fmt.Println("Device does not exist: ") + fmt.Println(err) + } else { + fmt.Printf("This is inconvenient: %v", err) + } + } + + if _, err := ncutils.RunCmd(ipExec+" link set down dev "+ifacename, false); err != nil { + ncutils.Log("attempted to remove interface before editing") + return err + } + + if node.PostDown != "" { + runcmds := strings.Split(node.PostDown, "; ") + _ = ncutils.RunCmds(runcmds, true) + } + // set MTU of node interface + if _, err := ncutils.RunCmd(ipExec+" link set mtu "+strconv.Itoa(int(node.MTU))+" up dev "+ifacename, true); err != nil { + ncutils.Log("failed to create interface with mtu " + ifacename) + return err + } + + if node.PostUp != "" { + runcmds := strings.Split(node.PostUp, "; ") + _ = ncutils.RunCmds(runcmds, true) + } + if hasGateway { + for _, gateway := range gateways { + _, _ = ncutils.RunCmd(ipExec+" -4 route add "+gateway+" dev "+ifacename, true) + } + } + if node.Address6 != "" && node.IsDualStack == "yes" { + log.Println("[netclient] adding address: "+node.Address6, 1) + _, _ = ncutils.RunCmd(ipExec+" address add dev "+ifacename+" "+node.Address6+"/64", true) + } + } + + return err +} + +// RemoveConf - removes a configuration for a given WireGuard interface +func RemoveConf(iface string, printlog bool) error { + var err error + confPath := ncutils.GetNetclientPathSpecific() + iface + ".conf" + err = removeWGQuickConf(confPath, printlog) + return err +} + +// == Private Methods == + +func setKernelDevice(ifacename string, address string) error { + ipExec, err := exec.LookPath("ip") + if err != nil { + return err + } + + _, _ = ncutils.RunCmd("ip link delete dev "+ifacename, false) + _, _ = ncutils.RunCmd(ipExec+" link add dev "+ifacename+" type wireguard", true) + _, _ = ncutils.RunCmd(ipExec+" address add dev "+ifacename+" "+address+"/24", true) + + return nil +} + +func applyWGQuickConf(confPath string) error { + if _, err := ncutils.RunCmd("wg-quick up "+confPath, true); err != nil { + return err + } + return nil +} + +func removeWGQuickConf(confPath string, printlog bool) error { + if _, err := ncutils.RunCmd("wg-quick down "+confPath, printlog); err != nil { + return err + } + return nil +} diff --git a/netclient/command/commands.go b/netclient/command/commands.go index 4d7628f4..ae9e58e9 100644 --- a/netclient/command/commands.go +++ b/netclient/command/commands.go @@ -25,8 +25,11 @@ var ( func Join(cfg config.ClientConfig, privateKey string) error { - err := functions.JoinNetwork(cfg, privateKey) - + var err error + err = functions.JoinNetwork(cfg, privateKey) + if err != nil && cfg.Node.IsServer != "yes" { // make sure server side is cleaned up + return err + } if err != nil && !cfg.DebugJoin { if !strings.Contains(err.Error(), "ALREADY_INSTALLED") { ncutils.PrintLog("error installing: "+err.Error(), 1) diff --git a/netclient/functions/checkin.go b/netclient/functions/checkin.go index 74069ef6..aa6eff12 100644 --- a/netclient/functions/checkin.go +++ b/netclient/functions/checkin.go @@ -259,6 +259,7 @@ func Pull(network string, manual bool) (*models.Node, error) { // Push - pushes current client configuration to server func Push(network string) error { + cfg, err := config.ReadConfig(network) if err != nil { return err @@ -268,58 +269,52 @@ func Push(network string) error { postnode.OS = runtime.GOOS postnode.SetLastCheckIn() - if postnode.IsServer != "yes" { // handle client side - var header metadata.MD - var wcclient nodepb.NodeServiceClient - conn, err := grpc.Dial(cfg.Server.GRPCAddress, - ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) - if err != nil { - ncutils.PrintLog("Cant dial GRPC server: "+err.Error(), 1) - return err - } - defer conn.Close() - wcclient = nodepb.NewNodeServiceClient(conn) + var header metadata.MD + var wcclient nodepb.NodeServiceClient + conn, err := grpc.Dial(cfg.Server.GRPCAddress, + ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) + if err != nil { + ncutils.PrintLog("Cant dial GRPC server: "+err.Error(), 1) + return err + } + defer conn.Close() + wcclient = nodepb.NewNodeServiceClient(conn) - ctx, err := auth.SetJWT(wcclient, network) - if err != nil { - ncutils.PrintLog("Failed to authenticate with server: "+err.Error(), 1) - return err - } - if postnode.IsPending != "yes" { - privateKey, err := wireguard.RetrievePrivKey(network) - if err != nil { - return err - } - privateKeyWG, err := wgtypes.ParseKey(privateKey) - if err != nil { - return err - } - if postnode.PublicKey != privateKeyWG.PublicKey().String() { - postnode.PublicKey = privateKeyWG.PublicKey().String() - } - } - nodeData, err := json.Marshal(&postnode) + ctx, err := auth.SetJWT(wcclient, network) + if err != nil { + ncutils.PrintLog("Failed to authenticate with server: "+err.Error(), 1) + return err + } + if postnode.IsPending != "yes" { + privateKey, err := wireguard.RetrievePrivKey(network) if err != nil { return err } + privateKeyWG, err := wgtypes.ParseKey(privateKey) + if err != nil { + return err + } + if postnode.PublicKey != privateKeyWG.PublicKey().String() { + postnode.PublicKey = privateKeyWG.PublicKey().String() + } + } + nodeData, err := json.Marshal(&postnode) + if err != nil { + return err + } - req := &nodepb.Object{ - Data: string(nodeData), - Type: nodepb.NODE_TYPE, - Metadata: "", - } - data, err := wcclient.UpdateNode(ctx, req, grpc.Header(&header)) - if err != nil { - return err - } - err = json.Unmarshal([]byte(data.Data), &postnode) - if err != nil { - return err - } - } else { - if err = postnode.Update(&postnode); err != nil { - return err - } + req := &nodepb.Object{ + Data: string(nodeData), + Type: nodepb.NODE_TYPE, + Metadata: "", + } + data, err := wcclient.UpdateNode(ctx, req, grpc.Header(&header)) + if err != nil { + return err + } + err = json.Unmarshal([]byte(data.Data), &postnode) + if err != nil { + return err } err = config.ModConfig(&postnode) return err diff --git a/netclient/functions/join.go b/netclient/functions/join.go index ee5d3e63..3207d8b2 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -8,7 +8,6 @@ import ( "log" nodepb "github.com/gravitl/netmaker/grpc" - "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config" @@ -24,29 +23,30 @@ import ( // JoinNetwork - helps a client join a network func JoinNetwork(cfg config.ClientConfig, privateKey string) error { - hasnet := local.HasNetwork(cfg.Network) - if hasnet { - err := errors.New("ALREADY_INSTALLED. Netclient appears to already be installed for " + cfg.Network + ". To re-install, please remove by executing 'sudo netclient leave -n " + cfg.Network + "'. Then re-run the install command.") - return err - } - - err := config.Write(&cfg, cfg.Network) - if err != nil { - return err - } - if cfg.Node.Network == "" { return errors.New("no network provided") } + var err error + if cfg.Node.IsServer != "yes" { + if local.HasNetwork(cfg.Network) { + err := errors.New("ALREADY_INSTALLED. Netclient appears to already be installed for " + cfg.Network + ". To re-install, please remove by executing 'sudo netclient leave -n " + cfg.Network + "'. Then re-run the install command.") + return err + } + err = config.Write(&cfg, cfg.Network) + if err != nil { + return err + } + if cfg.Node.Password == "" { + cfg.Node.Password = ncutils.GenPass() + } + auth.StoreSecret(cfg.Node.Password, cfg.Node.Network) + } + if cfg.Node.LocalRange != "" && cfg.Node.LocalAddress == "" { log.Println("local vpn, getting local address from range: " + cfg.Node.LocalRange) cfg.Node.LocalAddress = getLocalIP(cfg.Node) } - if cfg.Node.Password == "" { - cfg.Node.Password = ncutils.GenPass() - } - auth.StoreSecret(cfg.Node.Password, cfg.Node.Network) // set endpoint if blank. set to local if local net, retrieve from function if not if cfg.Node.Endpoint == "" { @@ -140,19 +140,6 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { if err = json.Unmarshal([]byte(nodeData), &node); err != nil { return err } - } else { // handle server side node creation - ncutils.Log("adding a server instance on network " + postnode.Network) - if err = config.ModConfig(postnode); err != nil { - return err - } - node, err = logic.CreateNode(*postnode, cfg.Network) - if err != nil { - return err - } - err = logic.SetNetworkNodesLastModified(node.Network) - if err != nil { - return err - } } // get free port based on returned default listen port @@ -169,28 +156,26 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { } node.Endpoint = node.LocalAddress } - - err = config.ModConfig(&node) - if err != nil { - return err - } - - err = wireguard.StorePrivKey(privateKey, cfg.Network) - if err != nil { - return err - } - - // pushing any local changes to server before starting wireguard - err = Push(cfg.Network) - if err != nil { - return err - } - - if node.IsPending == "yes" { - ncutils.Log("Node is marked as PENDING.") - ncutils.Log("Awaiting approval from Admin before configuring WireGuard.") - if cfg.Daemon != "off" { - return daemon.InstallDaemon(cfg) + if node.IsServer != "yes" { // == handle client side == + err = config.ModConfig(&node) + if err != nil { + return err + } + err = wireguard.StorePrivKey(privateKey, cfg.Network) + if err != nil { + return err + } + if node.IsPending == "yes" { + ncutils.Log("Node is marked as PENDING.") + ncutils.Log("Awaiting approval from Admin before configuring WireGuard.") + if cfg.Daemon != "off" { + return daemon.InstallDaemon(cfg) + } + } + // pushing any local changes to server before starting wireguard + err = Push(cfg.Network) + if err != nil { + return err } } diff --git a/netclient/server/grpc.go b/netclient/server/grpc.go index 9816cc71..60019341 100644 --- a/netclient/server/grpc.go +++ b/netclient/server/grpc.go @@ -9,7 +9,6 @@ import ( "time" nodepb "github.com/gravitl/netmaker/grpc" - "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config" @@ -73,73 +72,21 @@ func CheckIn(network string) (*models.Node, error) { return &node, err } -/* -func RemoveNetwork(network string) error { - //need to implement checkin on server side - cfg, err := config.ReadConfig(network) - if err != nil { - return err - } - servercfg := cfg.Server - node := cfg.Node - log.Println("Deleting remote node with MAC: " + node.MacAddress) - - var wcclient nodepb.NodeServiceClient - conn, err := grpc.Dial(cfg.Server.GRPCAddress, - ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) - if err != nil { - log.Printf("Unable to establish client connection to "+servercfg.GRPCAddress+": %v", err) - //return err - } else { - wcclient = nodepb.NewNodeServiceClient(conn) - ctx, err := auth.SetJWT(wcclient, network) - if err != nil { - //return err - log.Printf("Failed to authenticate: %v", err) - } else { - - var header metadata.MD - - _, err = wcclient.DeleteNode( - ctx, - &nodepb.Object{ - Data: node.MacAddress + "###" + node.Network, - Type: nodepb.STRING_TYPE, - }, - grpc.Header(&header), - ) - if err != nil { - log.Printf("Encountered error deleting node: %v", err) - log.Println(err) - } else { - log.Println("Deleted node " + node.MacAddress) - } - } - } - //err = functions.RemoveLocalInstance(network) - - return err -} -*/ - // GetPeers - gets the peers for a node func GetPeers(macaddress string, network string, server string, dualstack bool, isIngressGateway bool, isServer bool) ([]wgtypes.PeerConfig, bool, []string, error) { hasGateway := false + var err error var gateways []string var peers []wgtypes.PeerConfig - cfg, err := config.ReadConfig(network) - if err != nil { - log.Fatalf("Issue retrieving config for network: "+network+". Please investigate: %v", err) - } - nodecfg := cfg.Node - keepalive := nodecfg.PersistentKeepalive - keepalivedur, err := time.ParseDuration(strconv.FormatInt(int64(keepalive), 10) + "s") - keepaliveserver, err := time.ParseDuration(strconv.FormatInt(int64(5), 10) + "s") - if err != nil { - log.Fatalf("Issue with format of keepalive value. Please update netconfig: %v", err) - } - var nodes []models.Node // fill this either from server or client - if !isServer { // set peers client side + var nodecfg models.Node + var nodes []models.Node // fill above fields from server or client + + if !isServer { // set peers client side + cfg, err := config.ReadConfig(network) + if err != nil { + log.Fatalf("Issue retrieving config for network: "+network+". Please investigate: %v", err) + } + nodecfg = cfg.Node var wcclient nodepb.NodeServiceClient conn, err := grpc.Dial(cfg.Server.GRPCAddress, ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) @@ -173,11 +120,13 @@ func GetPeers(macaddress string, network string, server string, dualstack bool, log.Println("Error unmarshaling data for peers") return nil, hasGateway, gateways, err } - } else { // set peers serverside - nodes, err = logic.GetPeers(nodecfg) - if err != nil { - return nil, hasGateway, gateways, err - } + } + + keepalive := nodecfg.PersistentKeepalive + keepalivedur, err := time.ParseDuration(strconv.FormatInt(int64(keepalive), 10) + "s") + keepaliveserver, err := time.ParseDuration(strconv.FormatInt(int64(5), 10) + "s") + if err != nil { + log.Fatalf("Issue with format of keepalive value. Please update netconfig: %v", err) } for _, node := range nodes { @@ -299,14 +248,18 @@ func GetPeers(macaddress string, network string, server string, dualstack bool, // GetExtPeers - gets the extpeers for a client func GetExtPeers(macaddress string, network string, server string, dualstack bool) ([]wgtypes.PeerConfig, error) { var peers []wgtypes.PeerConfig - - cfg, err := config.ReadConfig(network) - if err != nil { - log.Fatalf("Issue retrieving config for network: "+network+". Please investigate: %v", err) - } - nodecfg := cfg.Node + var nodecfg models.Node var extPeers []models.Node + var err error + // fill above fields from either client or server + if nodecfg.IsServer != "yes" { // fill extPeers with client side logic + var cfg *config.ClientConfig + cfg, err = config.ReadConfig(network) + if err != nil { + log.Fatalf("Issue retrieving config for network: "+network+". Please investigate: %v", err) + } + nodecfg = cfg.Node var wcclient nodepb.NodeServiceClient conn, err := grpc.Dial(cfg.Server.GRPCAddress, @@ -339,22 +292,6 @@ func GetExtPeers(macaddress string, network string, server string, dualstack boo if err = json.Unmarshal([]byte(responseObject.Data), &extPeers); err != nil { return nil, err } - } else { // fill extPeers with server side logic - tempPeers, err := logic.GetExtPeersList(nodecfg.MacAddress, nodecfg.Network) - if err != nil { - return nil, err - } - for i := 0; i < len(tempPeers); i++ { - extPeers = append(extPeers, models.Node{ - Address: tempPeers[i].Address, - Address6: tempPeers[i].Address6, - Endpoint: tempPeers[i].Endpoint, - PublicKey: tempPeers[i].PublicKey, - PersistentKeepalive: tempPeers[i].KeepAlive, - ListenPort: tempPeers[i].ListenPort, - LocalAddress: tempPeers[i].LocalAddress, - }) - } } for _, extPeer := range extPeers { pubkey, err := wgtypes.ParseKey(extPeer.PublicKey) diff --git a/netclient/wireguard/unix.go b/netclient/wireguard/unix.go index b84aecf4..ccb8115e 100644 --- a/netclient/wireguard/unix.go +++ b/netclient/wireguard/unix.go @@ -66,8 +66,9 @@ func RemoveWGQuickConf(confPath string, printlog bool) error { // StorePrivKey - stores wg priv key on disk locally func StorePrivKey(key string, network string) error { + var err error d1 := []byte(key) - err := ioutil.WriteFile(ncutils.GetNetclientPathSpecific()+"wgkey-"+network, d1, 0644) + err = ioutil.WriteFile(ncutils.GetNetclientPathSpecific()+"wgkey-"+network, d1, 0644) return err }