diff --git a/controllers/node.go b/controllers/node.go index e11887be..3a7bf974 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -31,7 +31,7 @@ func nodeHandlers(r *mux.Router) { r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST") - r.HandleFunc("/api/nodes/{network}", createNode).Methods("POST") + r.HandleFunc("/api/nodes/{network}", nodeauth(http.HandlerFunc(createNode))).Methods("POST") r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET") r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods("POST") } @@ -131,6 +131,51 @@ func authenticate(response http.ResponseWriter, request *http.Request) { } } +// auth middleware for api calls from nodes where node is has not yet joined the server (register, join) +func nodeauth(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + bearerToken := r.Header.Get("Authorization") + var tokenSplit = strings.Split(bearerToken, " ") + var token = "" + if len(tokenSplit) < 2 { + errorResponse := models.ErrorResponse{ + Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", + } + returnErrorResponse(w, r, errorResponse) + return + } else { + token = tokenSplit[1] + } + found := false + networks, err := logic.GetNetworks() + if err != nil { + logger.Log(0, "no networks", err.Error()) + errorResponse := models.ErrorResponse{ + Code: http.StatusNotFound, Message: "no networks", + } + returnErrorResponse(w, r, errorResponse) + return + } + for _, network := range networks { + for _, key := range network.AccessKeys { + if key.Value == token { + found = true + break + } + } + } + if !found { + logger.Log(0, "valid access key not found") + errorResponse := models.ErrorResponse{ + Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.", + } + returnErrorResponse(w, r, errorResponse) + return + } + next.ServeHTTP(w, r) + } +} + //The middleware for most requests to the API //They all pass through here first //This will validate the JWT (or check for master token) diff --git a/controllers/server.go b/controllers/server.go index 14399605..7fa824ae 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -22,7 +22,7 @@ func serverHandlers(r *mux.Router) { // r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods("POST") r.HandleFunc("/api/server/getconfig", securityCheckServer(false, http.HandlerFunc(getConfig))).Methods("GET") r.HandleFunc("/api/server/removenetwork/{network}", securityCheckServer(true, http.HandlerFunc(removeNetwork))).Methods("DELETE") - r.HandleFunc("/api/server/register", http.HandlerFunc(register)).Methods("POST") + r.HandleFunc("/api/server/register", nodeauth(http.HandlerFunc(register))).Methods("POST") } //Security check is middleware for every function and just checks to make sure that its the master calling @@ -115,18 +115,6 @@ func getConfig(w http.ResponseWriter, r *http.Request) { func register(w http.ResponseWriter, r *http.Request) { logger.Log(2, "processing registration request") w.Header().Set("Content-Type", "application/json") - bearerToken := r.Header.Get("Authorization") - var tokenSplit = strings.Split(bearerToken, " ") - var token = "" - if len(tokenSplit) < 2 { - errorResponse := models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: "W1R3: You are unauthorized to access this endpoint.", - } - returnErrorResponse(w, r, errorResponse) - return - } else { - token = tokenSplit[1] - } //decode body var request config.RegisterRequest if err := json.NewDecoder(r.Body).Decode(&request); err != nil { @@ -137,32 +125,6 @@ func register(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, errorResponse) return } - found := false - networks, err := logic.GetNetworks() - if err != nil { - logger.Log(0, "no networks", err.Error()) - errorResponse := models.ErrorResponse{ - Code: http.StatusNotFound, Message: "no networks", - } - returnErrorResponse(w, r, errorResponse) - return - } - for _, network := range networks { - for _, key := range network.AccessKeys { - if key.Value == token { - found = true - break - } - } - } - if !found { - logger.Log(0, "valid access key not found") - errorResponse := models.ErrorResponse{ - Code: http.StatusUnauthorized, Message: "You are unauthorized to access this endpoint.", - } - returnErrorResponse(w, r, errorResponse) - return - } cert, ca, err := genCerts(&request.Key, &request.CommonName) if err != nil { logger.Log(0, "failed to generater certs ", err.Error()) diff --git a/netclient/functions/common.go b/netclient/functions/common.go index a7a38de9..27fa60d2 100644 --- a/netclient/functions/common.go +++ b/netclient/functions/common.go @@ -352,11 +352,16 @@ func Authenticate(cfg *config.ClientConfig) (string, error) { if err != nil { return "", err } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + bodybytes, _ := ioutil.ReadAll(response.Body) + return "", fmt.Errorf("failed to authenticate %s %s", response.Status, string(bodybytes)) + } resp := models.SuccessResponse{} if err := json.NewDecoder(response.Body).Decode(&resp); err != nil { - return "", err + return "", fmt.Errorf("error decoding respone %w", err) } - tokenData := (resp.Response.(map[string]interface{})) + tokenData := resp.Response.(map[string]interface{}) token := tokenData["AuthToken"] return token.(string), nil } diff --git a/netclient/functions/join.go b/netclient/functions/join.go index a0473244..52980a62 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -1,15 +1,16 @@ package functions import ( - "context" "crypto/rand" "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "log" + "net/http" "runtime" - nodepb "github.com/gravitl/netmaker/grpc" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -18,11 +19,9 @@ import ( "github.com/gravitl/netmaker/netclient/daemon" "github.com/gravitl/netmaker/netclient/local" "github.com/gravitl/netmaker/netclient/ncutils" - "github.com/gravitl/netmaker/netclient/server" "github.com/gravitl/netmaker/netclient/wireguard" "golang.org/x/crypto/nacl/box" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc" ) // JoinNetwork - helps a client join a network @@ -123,45 +122,28 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string, iscomms bool) erro } // make sure name is appropriate, if not, give blank name cfg.Node.Name = formatName(cfg.Node) - // differentiate between client/server here - var node = models.Node{ - Password: cfg.Node.Password, - Address: cfg.Node.Address, - Address6: cfg.Node.Address6, - ID: cfg.Node.ID, - MacAddress: cfg.Node.MacAddress, - AccessKey: cfg.Server.AccessKey, - IsStatic: cfg.Node.IsStatic, - //Roaming: cfg.Node.Roaming, - Network: cfg.Network, - ListenPort: cfg.Node.ListenPort, - PostUp: cfg.Node.PostUp, - PostDown: cfg.Node.PostDown, - PersistentKeepalive: cfg.Node.PersistentKeepalive, - LocalAddress: cfg.Node.LocalAddress, - Interface: cfg.Node.Interface, - PublicKey: cfg.Node.PublicKey, - DNSOn: cfg.Node.DNSOn, - Name: cfg.Node.Name, - Endpoint: cfg.Node.Endpoint, - UDPHolePunch: cfg.Node.UDPHolePunch, - TrafficKeys: cfg.Node.TrafficKeys, - OS: runtime.GOOS, - Version: ncutils.Version, - } - - logger.Log(0, "joining "+cfg.Network+" at "+cfg.Server.GRPCAddress) - var wcclient nodepb.NodeServiceClient - - conn, err := grpc.Dial(cfg.Server.GRPCAddress, - ncutils.GRPCRequestOpts(cfg.Server.GRPCSSL)) - + cfg.Node.OS = runtime.GOOS + cfg.Node.Version = ncutils.Version + cfg.Node.AccessKey = cfg.Server.AccessKey + //not sure why this is needed ... setnode defaults should take care of this on server + cfg.Node.IPForwarding = "yes" + logger.Log(0, "joining "+cfg.Network+" at "+cfg.Server.API) + url := "https://" + cfg.Server.API + "/api/nodes/" + cfg.Network + response, err := API(cfg.Node, http.MethodPost, url, cfg.Server.AccessKey) if err != nil { - log.Fatalf("Unable to establish client connection to "+cfg.Server.GRPCAddress+": %v", err) + return fmt.Errorf("error creating node %w", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + bodybytes, _ := io.ReadAll(response.Body) + return fmt.Errorf("error creating node %s %s", response.Status, string(bodybytes)) + } + var node models.Node + if err := json.NewDecoder(response.Body).Decode(&node); err != nil { + //not sure the next line will work as response.Body probably needs to be reset before it can be read again + bodybytes, _ := ioutil.ReadAll(response.Body) + return fmt.Errorf("error decoding node from server %w %s", err, string(bodybytes)) } - defer conn.Close() - wcclient = nodepb.NewNodeServiceClient(conn) - // safety check. If returned node from server is local, but not currently configured as local, set to local addr if cfg.Node.IsLocal != "yes" && node.IsLocal == "yes" && node.LocalRange != "" { node.LocalAddress, err = ncutils.GetLocalIP(node.LocalRange) @@ -186,35 +168,11 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string, iscomms bool) erro return daemon.InstallDaemon(cfg) } } - data, err := json.Marshal(&node) - if err != nil { - return err - } - // Create node on server - res, err := wcclient.CreateNode( - context.TODO(), - &nodepb.Object{ - Data: string(data), - Type: nodepb.NODE_TYPE, - }, - ) - if err != nil { - return err - } logger.Log(1, "node created on remote server...updating configs") - // keep track of the old listenport value oldListenPort := node.ListenPort - - nodeData := res.Data - if err = json.Unmarshal([]byte(nodeData), &node); err != nil { - return err - } - cfg.Node = node - setListenPort(oldListenPort, cfg) - err = config.ModConfig(&cfg.Node) if err != nil { return err @@ -223,45 +181,19 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string, iscomms bool) erro if err = config.SaveBackup(node.Network); err != nil { logger.Log(0, "failed to make backup, node will not auto restore if config is corrupted") } - - logger.Log(0, "retrieving peers") - peers, hasGateway, gateways, err := server.GetPeers(node.MacAddress, cfg.Network, cfg.Server.GRPCAddress, node.IsDualStack == "yes", node.IsIngressGateway == "yes", node.IsServer == "yes") - if err != nil && !ncutils.IsEmptyRecord(err) { - logger.Log(0, "failed to retrieve peers") - return err - } - logger.Log(0, "starting wireguard") - err = wireguard.InitWireguard(&node, privateKey, peers, hasGateway, gateways, false) + err = wireguard.InitWireguard(&node, privateKey, []wgtypes.PeerConfig{}, false, []string{}, false) if err != nil { return err } - // if node.DNSOn == "yes" { - // for _, server := range node.NetworkSettings.DefaultServerAddrs { - // if server.IsLeader { - // go func() { - // if !local.SetDNSWithRetry(node, server.Address) { - // cfg.Node.DNSOn = "no" - // var currentCommsCfg = getCommsCfgByNode(&cfg.Node) - // PublishNodeUpdate(¤tCommsCfg, &cfg) - // } - // }() - // break - // } - // } - // } - - if !iscomms { - if cfg.Daemon != "off" { - err = daemon.InstallDaemon(cfg) - } + if cfg.Daemon != "off" { + err = daemon.InstallDaemon(cfg) if err != nil { return err } else { daemon.Restart() } } - return nil }