diff --git a/Dockerfile-quick b/Dockerfile-quick new file mode 100644 index 00000000..08f5bc74 --- /dev/null +++ b/Dockerfile-quick @@ -0,0 +1,14 @@ +#first stage - builder +FROM alpine:3.15.2 +ARG version +WORKDIR /app +COPY ./netmaker /root/netmaker +ENV GO111MODULE=auto + +# add a c lib +RUN apk add gcompat iptables wireguard-tools +# set the working directory +WORKDIR /root/ +RUN mkdir -p /etc/netclient/config +EXPOSE 8081 +ENTRYPOINT ["./netmaker"] diff --git a/auth/auth.go b/auth/auth.go index bd7bcd3d..cdc2581b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -9,6 +9,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro/netcache" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" @@ -27,8 +28,19 @@ const ( oidc_provider_name = "oidc" verify_user = "verifyuser" auth_key = "netmaker_auth" + user_signin_length = 16 + node_signin_length = 64 ) +// OAuthUser - generic OAuth strategy user +type OAuthUser struct { + Name string `json:"name" bson:"name"` + Email string `json:"email" bson:"email"` + Login string `json:"login" bson:"login"` + UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"` + AccessToken string `json:"accesstoken" bson:"accesstoken"` +} + var auth_provider *oauth2.Config func getCurrentAuthFunctions() map[string]interface{} { @@ -94,7 +106,14 @@ func HandleAuthCallback(w http.ResponseWriter, r *http.Request) { if functions == nil { return } - functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r) + state, _ := getStateAndCode(r) + _, err := netcache.Get(state) // if in netcache proceeed with node registration login + if err == nil || len(state) == node_signin_length || (err != nil && strings.Contains(err.Error(), "expired")) { + logger.Log(0, "proceeding with node SSO callback") + HandleNodeSSOCallback(w, r) + } else { // handle normal login + functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r) + } } // swagger:route GET /api/oauth/login nodes HandleAuthLogin @@ -197,3 +216,35 @@ func fetchPassValue(newValue string) (string, error) { } return string(b64CurrentValue), nil } + +func getStateAndCode(r *http.Request) (string, string) { + var state, code string + if r.FormValue("state") != "" && r.FormValue("code") != "" { + state = r.FormValue("state") + code = r.FormValue("code") + } else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" { + state = r.URL.Query().Get("state") + code = r.URL.Query().Get("code") + } + + return state, code +} + +func (user *OAuthUser) getUserName() string { + var userName string + if user.Email != "" { + userName = user.Email + } else if user.Login != "" { + userName = user.Login + } else if user.UserPrincipalName != "" { + userName = user.UserPrincipalName + } else if user.Name != "" { + userName = user.Name + } + return userName +} + +func isStateCached(state string) bool { + _, err := netcache.Get(state) + return err == nil || strings.Contains(err.Error(), "expired") +} diff --git a/auth/azure-ad.go b/auth/azure-ad.go index b2931b50..ac150d60 100644 --- a/auth/azure-ad.go +++ b/auth/azure-ad.go @@ -23,11 +23,6 @@ var azure_ad_functions = map[string]interface{}{ verify_user: verifyAzureUser, } -type azureOauthUser struct { - UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"` - AccessToken string `json:"accesstoken" bson:"accesstoken"` -} - // == handle azure ad authentication here == func initAzureAD(redirectURL string, clientID string, clientSecret string) { @@ -41,7 +36,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) { } func handleAzureLogin(w http.ResponseWriter, r *http.Request) { - var oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(user_signin_length) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -61,7 +56,8 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) { func handleAzureCallback(w http.ResponseWriter, r *http.Request) { - var content, err = getAzureUserInfo(r.FormValue("state"), r.FormValue("code")) + var rState, rCode = getStateAndCode(r) + var content, err = getAzureUserInfo(rState, rCode) if err != nil { logger.Log(1, "error when getting user info from azure:", err.Error()) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) @@ -93,9 +89,9 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect) } -func getAzureUserInfo(state string, code string) (*azureOauthUser, error) { +func getAzureUserInfo(state string, code string) (*OAuthUser, error) { oauth_state_string, isValid := logic.IsStateValid(state) - if !isValid || state != oauth_state_string { + if (!isValid || state != oauth_state_string) && !isStateCached(state) { return nil, fmt.Errorf("invalid oauth state") } var token, err = auth_provider.Exchange(context.Background(), code) @@ -121,7 +117,7 @@ func getAzureUserInfo(state string, code string) (*azureOauthUser, error) { if err != nil { return nil, fmt.Errorf("failed reading response body: %s", err.Error()) } - var userInfo = &azureOauthUser{} + var userInfo = &OAuthUser{} if err = json.Unmarshal(contents, userInfo); err != nil { return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) } diff --git a/auth/github.go b/auth/github.go index 2bbdfdea..b333f50c 100644 --- a/auth/github.go +++ b/auth/github.go @@ -23,11 +23,6 @@ var github_functions = map[string]interface{}{ verify_user: verifyGithubUser, } -type githubOauthUser struct { - Login string `json:"login" bson:"login"` - AccessToken string `json:"accesstoken" bson:"accesstoken"` -} - // == handle github authentication here == func initGithub(redirectURL string, clientID string, clientSecret string) { @@ -41,7 +36,7 @@ func initGithub(redirectURL string, clientID string, clientSecret string) { } func handleGithubLogin(w http.ResponseWriter, r *http.Request) { - var oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(user_signin_length) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -61,7 +56,8 @@ func handleGithubLogin(w http.ResponseWriter, r *http.Request) { func handleGithubCallback(w http.ResponseWriter, r *http.Request) { - var content, err = getGithubUserInfo(r.URL.Query().Get("state"), r.URL.Query().Get("code")) + var rState, rCode = getStateAndCode(r) + var content, err = getGithubUserInfo(rState, rCode) if err != nil { logger.Log(1, "error when getting user info from github:", err.Error()) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) @@ -93,10 +89,10 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect) } -func getGithubUserInfo(state string, code string) (*githubOauthUser, error) { +func getGithubUserInfo(state string, code string) (*OAuthUser, error) { oauth_state_string, isValid := logic.IsStateValid(state) - if !isValid || state != oauth_state_string { - return nil, fmt.Errorf("invalid OAuth state") + if (!isValid || state != oauth_state_string) && !isStateCached(state) { + return nil, fmt.Errorf("invalid oauth state") } var token, err = auth_provider.Exchange(context.Background(), code) if err != nil { @@ -125,7 +121,7 @@ func getGithubUserInfo(state string, code string) (*githubOauthUser, error) { if err != nil { return nil, fmt.Errorf("failed reading response body: %s", err.Error()) } - var userInfo = &githubOauthUser{} + var userInfo = &OAuthUser{} if err = json.Unmarshal(contents, userInfo); err != nil { return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) } diff --git a/auth/google.go b/auth/google.go index 344c9938..22be2a47 100644 --- a/auth/google.go +++ b/auth/google.go @@ -24,11 +24,6 @@ var google_functions = map[string]interface{}{ verify_user: verifyGoogleUser, } -type googleOauthUser struct { - Email string `json:"email" bson:"email"` - AccessToken string `json:"accesstoken" bson:"accesstoken"` -} - // == handle google authentication here == func initGoogle(redirectURL string, clientID string, clientSecret string) { @@ -42,7 +37,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) { } func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { - var oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(user_signin_length) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -62,7 +57,9 @@ func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { - var content, err = getGoogleUserInfo(r.FormValue("state"), r.FormValue("code")) + var rState, rCode = getStateAndCode(r) + + var content, err = getGoogleUserInfo(rState, rCode) if err != nil { logger.Log(1, "error when getting user info from google:", err.Error()) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) @@ -94,10 +91,10 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) } -func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) { +func getGoogleUserInfo(state string, code string) (*OAuthUser, error) { oauth_state_string, isValid := logic.IsStateValid(state) - if !isValid || state != oauth_state_string { - return nil, fmt.Errorf("invalid OAuth state") + if (!isValid || state != oauth_state_string) && !isStateCached(state) { + return nil, fmt.Errorf("invalid oauth state") } var token, err = auth_provider.Exchange(context.Background(), code) if err != nil { @@ -120,7 +117,7 @@ func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) { if err != nil { return nil, fmt.Errorf("failed reading response body: %s", err.Error()) } - var userInfo = &googleOauthUser{} + var userInfo = &OAuthUser{} if err = json.Unmarshal(contents, userInfo); err != nil { return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) } diff --git a/auth/nodecallback.go b/auth/nodecallback.go new file mode 100644 index 00000000..d7bfae9c --- /dev/null +++ b/auth/nodecallback.go @@ -0,0 +1,259 @@ +package auth + +import ( + "bytes" + "fmt" + "net/http" + "time" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" + "github.com/gravitl/netmaker/logic/pro/netcache" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" + "github.com/gravitl/netmaker/servercfg" +) + +var ( + redirectUrl string +) + +// HandleNodeSSOCallback handles the callback from the sso endpoint +// It is the analogue of auth.handleNodeSSOCallback but takes care of the end point flow +// Retrieves the mkey from the state cache and adds the machine to the users email namespace +// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities +// TODO: Add groups information from OIDC tokens into machine HostInfo +// Listens in /oidc/callback. +func HandleNodeSSOCallback(w http.ResponseWriter, r *http.Request) { + + var functions = getCurrentAuthFunctions() + if functions == nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("bad conf")) + logger.Log(0, "Missing Oauth config in HandleNodeSSOCallback") + return + } + + state, code := getStateAndCode(r) + + var userClaims, err = functions[get_user_info].(func(string, string) (*OAuthUser, error))(state, code) + if err != nil { + logger.Log(0, "error when getting user info from callback:", err.Error()) + http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + + if code == "" || state == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Wrong params")) + logger.Log(0, "Missing params in HandleSSOCallback") + return + } + + // all responses should be in html format from here on out + w.Header().Add("content-type", "text/html; charset=utf-8") + + // retrieve machinekey from state cache + reqKeyIf, machineKeyFoundErr := netcache.Get(state) + if machineKeyFoundErr != nil { + logger.Log(0, "requested machine state key expired before authorisation completed -", err.Error()) + reqKeyIf = &netcache.CValue{ + Network: "invalid", + Value: state, + Pass: "", + User: "netmaker", + Expiration: time.Now(), + } + response := returnErrTemplate("", "requested machine state key expired before authorisation completed", state, reqKeyIf) + w.WriteHeader(http.StatusInternalServerError) + w.Write(response) + return + } + + user, err := isUserIsAllowed(userClaims.getUserName(), reqKeyIf.Network, true) + if err != nil { + logger.Log(0, "error occurred during SSO node join for user", userClaims.getUserName(), "on network", reqKeyIf.Network, "-", err.Error()) + response := returnErrTemplate(user.UserName, err.Error(), state, reqKeyIf) + w.WriteHeader(http.StatusNotAcceptable) + w.Write(response) + return + } + + logger.Log(1, "registering new node for user:", user.UserName, "on network", reqKeyIf.Network) + + // Send OK to user in the browser + var response bytes.Buffer + if err := ssoCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{ + User: userClaims.getUserName(), + Verb: "Authenticated", + }); err != nil { + logger.Log(0, "Could not render SSO callback template ", err.Error()) + response := returnErrTemplate(user.UserName, "Could not render SSO callback template", state, reqKeyIf) + w.WriteHeader(http.StatusInternalServerError) + w.Write(response) + + } else { + w.WriteHeader(http.StatusOK) + w.Write(response.Bytes()) + } + + // Need to send access key to the client + logger.Log(1, "Handling new machine addition to network", + reqKeyIf.Network, "with key", + reqKeyIf.Value, " identity:", userClaims.getUserName(), "claims:", fmt.Sprintf("%+v", userClaims)) + + var answer string + // The registation logic is starting here: + // we request access key with 1 use for the required network + accessToken, err := requestAccessKey(reqKeyIf.Network, 1, userClaims.getUserName()) + if err != nil { + answer = fmt.Sprintf("Error from the netmaker controller %s", err.Error()) + } else { + answer = fmt.Sprintf("AccessToken: %s", accessToken) + } + logger.Log(0, "Updating the token for the client request ... ") + // Give the user the access token via Pass in the DB + reqKeyIf.Pass = answer + if err = netcache.Set(state, reqKeyIf); err != nil { + logger.Log(0, "machine failed to complete join on network,", reqKeyIf.Network, "-", err.Error()) + return + } +} + +func setNetcache(ncache *netcache.CValue, state string) error { + if ncache == nil { + return fmt.Errorf("cache miss") + } + var err error + if err = netcache.Set(state, ncache); err != nil { + logger.Log(0, "machine failed to complete join on network,", ncache.Network, "-", err.Error()) + } + return err +} + +func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) []byte { + var response bytes.Buffer + ncache.Pass = message + err := ssoErrCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{ + User: uname, + Verb: message, + }) + if err != nil { + return []byte(err.Error()) + } + err = setNetcache(ncache, state) + if err != nil { + return []byte(err.Error()) + } + return response.Bytes() +} + +// RegisterNodeSSO redirects to the IDP for authentication +// Puts machine key in cache so the callback can retrieve it using the oidc state param +// Listens in /oidc/register/:regKey. +func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) { + + logger.Log(1, "RegisterNodeSSO\n") + + vars := mux.Vars(r) + + // machineKeyStr this is not key but state + machineKeyStr := vars["regKey"] + logger.Log(1, "requested key:", machineKeyStr) + + if machineKeyStr == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Wrong params")) + logger.Log(0, "Wrong params ", machineKeyStr) + return + } + + // machineKeyStr this not key but state + authURL := auth_provider.AuthCodeURL(machineKeyStr) + //authURL = authURL + "&connector_id=" + "google" + logger.Log(0, "Redirecting to ", authURL, " for authentication") + + http.Redirect(w, r, authURL, http.StatusSeeOther) + +} + +// == private == +// API to create an access key for a given network with a given name +func requestAccessKey(network string, uses int, name string) (accessKey string, err error) { + + var sAccessKey models.AccessKey + var sNetwork models.Network + + sNetwork, err = logic.GetParentNetwork(network) + if err != nil { + logger.Log(0, "err calling GetParentNetwork API=%s", err.Error()) + return "", fmt.Errorf("internal controller error %s", err.Error()) + } + // If a key already exists, we recreate it. + // @TODO Is that a preferred handling ? We could also trying to re-use. + // can happen if user started log in but did not finish + for _, currentkey := range sNetwork.AccessKeys { + if currentkey.Name == name { + logger.Log(0, "erasing existing AccessKey for: ", name) + err = logic.DeleteKey(currentkey.Name, network) + if err != nil { + logger.Log(0, "err calling CreateAccessKey API ", err.Error()) + return "", fmt.Errorf("key already exists. Contact admin to resolve") + } + break + } + } + // Only one usage is needed - for the next time new access key will be required + // it will be created next time after another IdP approval + sAccessKey.Uses = 1 + sAccessKey.Name = name + + accessToken, err := logic.CreateAccessKey(sAccessKey, sNetwork) + if err != nil { + logger.Log(0, "err calling CreateAccessKey API ", err.Error()) + return "", fmt.Errorf("error from the netmaker controller %s", err.Error()) + } else { + logger.Log(1, "created access key", sAccessKey.Name, "on", network) + } + return accessToken.AccessString, nil +} + +func isUserIsAllowed(username, network string, shouldAddUser bool) (*models.User, error) { + + user, err := logic.GetUser(username) + if err != nil && shouldAddUser { // user must not exist, so try to make one + if err = addUser(username); err != nil { + logger.Log(0, "failed to add user", username, "during a node SSO network join on network", network) + // response := returnErrTemplate(user.UserName, "failed to add user", state, reqKeyIf) + // w.WriteHeader(http.StatusInternalServerError) + // w.Write(response) + return nil, fmt.Errorf("failed to add user to system") + } + logger.Log(0, "user", username, "was added during a node SSO network join on network", network) + user, _ = logic.GetUser(username) + } + + if !user.IsAdmin { // perform check to see if user is allowed to join a node to network + netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(user.UserName)) + if err != nil { + logger.Log(0, "failed to get net user details for user", user.UserName, "during node SSO") + return nil, fmt.Errorf("failed to verify network user") + } + if netUser.AccessLevel != pro.NET_ADMIN { // if user is a net admin on network, good to go + // otherwise, check if they have node access + haven't reached node limit on network + if netUser.AccessLevel == pro.NODE_ACCESS { + if len(netUser.Nodes) >= netUser.NodeLimit { + logger.Log(0, "user", user.UserName, "has reached their node limit on network", network) + return nil, fmt.Errorf("user node limit exceeded") + } + } else { + logger.Log(0, "user", user.UserName, "attempted to access network", network, "via node SSO") + return nil, fmt.Errorf("network user not allowed") + } + } + } + + return &user, nil +} diff --git a/auth/nodesession.go b/auth/nodesession.go new file mode 100644 index 00000000..b848c2b8 --- /dev/null +++ b/auth/nodesession.go @@ -0,0 +1,155 @@ +package auth + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro/netcache" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" + "github.com/gravitl/netmaker/servercfg" +) + +// SessionHandler - called by the HTTP router when user +// is calling netclient with --login-server parameter in order to authenticate +// via SSO mechanism by OAuth2 protocol flow. +// This triggers a session start and it is managed by the flow implmented here and callback +// When this method finishes - the auth flow has finished either OK or by timeout or any other error occured +func SessionHandler(conn *websocket.Conn) { + defer conn.Close() + logger.Log(1, "Running sessionHandler") + + // If reached here we have a session from user to handle... + messageType, message, err := conn.ReadMessage() + if err != nil { + logger.Log(0, "Error during message reading:", err.Error()) + return + } + var loginMessage promodels.LoginMsg + + err = json.Unmarshal(message, &loginMessage) + if err != nil { + logger.Log(0, "Failed to unmarshall data err=", err.Error()) + return + } + logger.Log(1, "SSO node join attempted with info network:", loginMessage.Network, "node identifier:", loginMessage.Mac, "user:", loginMessage.User) + + req := new(netcache.CValue) + req.Value = string(loginMessage.Mac) + req.Network = loginMessage.Network + req.Pass = "" + req.User = "" + // Add any extra parameter provided in the configuration to the Authorize Endpoint request?? + stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length))) + if err := netcache.Set(stateStr, req); err != nil { + logger.Log(0, "Failed to process sso request -", err.Error()) + return + } + // Wait for the user to finish his auth flow... + // TBD: what should be the timeout here ? + timeout := make(chan bool, 1) + answer := make(chan string, 1) + + if loginMessage.User != "" { // handle basic auth + // verify that server supports basic auth, then authorize the request with given credentials + // check if user is allowed to join via node sso + // i.e. user is admin or user has network permissions + if !servercfg.IsBasicAuthEnabled() { + err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled")) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + } + _, err := logic.VerifyAuthRequest(models.UserAuthParams{ + UserName: loginMessage.User, + Password: loginMessage.Password, + }) + if err != nil { + err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User))) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + return + } + user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false) + if err != nil { + err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User))) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + return + } + accessToken, err := requestAccessKey(loginMessage.Network, 1, user.UserName) + if err != nil { + req.Pass = fmt.Sprintf("Error from the netmaker controller %s", err.Error()) + } else { + req.Pass = fmt.Sprintf("AccessToken: %s", accessToken) + } + // Give the user the access token via Pass in the DB + if err = netcache.Set(stateStr, req); err != nil { + logger.Log(0, "machine failed to complete join on network,", loginMessage.Network, "-", err.Error()) + return + } + } else { // handle SSO / OAuth + redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr) + err = conn.WriteMessage(messageType, []byte(redirectUrl)) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + } + + go func() { + for { + cachedReq, err := netcache.Get(stateStr) + if err != nil { + if strings.Contains(err.Error(), "expired") { + logger.Log(0, "timeout occurred while waiting for SSO on network", loginMessage.Network) + timeout <- true + break + } + continue + } else if cachedReq.Pass != "" { + logger.Log(0, "node SSO process completed for user", cachedReq.User, "on network", loginMessage.Network) + answer <- cachedReq.Pass + break + } + time.Sleep(500) // try it 2 times per second to see if auth is completed + } + }() + + select { + case result := <-answer: + // a read from req.answerCh has occurred + err = conn.WriteMessage(messageType, []byte(result)) + if err != nil { + logger.Log(0, "Error during message writing:", err.Error()) + } + case <-timeout: + logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network) + // the read from req.answerCh has timed out + err = conn.WriteMessage(messageType, []byte("Authentication server time out")) + if err != nil { + logger.Log(0, "Error during message writing:", err.Error()) + } + } + // The entry is not needed anymore, but we will let the producer to close it to avoid panic cases + if err = netcache.Del(stateStr); err != nil { + logger.Log(0, "failed to remove node SSO cache entry", err.Error()) + } + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + logger.Log(0, "write close:", err.Error()) + return + } + time.After(time.Second) + close(answer) + close(timeout) +} diff --git a/auth/oidc.go b/auth/oidc.go index 77e26ad9..ad03b1c9 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -26,11 +26,6 @@ var oidc_functions = map[string]interface{}{ var oidc_verifier *oidc.IDTokenVerifier -type OIDCUser struct { - Name string `json:"name" bson:"name"` - Email string `json:"email" bson:"email"` -} - // == handle OIDC authentication here == func initOIDC(redirectURL string, clientID string, clientSecret string, issuer string) { @@ -54,7 +49,7 @@ func initOIDC(redirectURL string, clientID string, clientSecret string, issuer s } func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { - var oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(user_signin_length) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -67,14 +62,16 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return } - + logger.Log(3, "using state string:", oauth_state_string) var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { - var content, err = getOIDCUserInfo(r.FormValue("state"), r.FormValue("code")) + var rState, rCode = getStateAndCode(r) + + var content, err = getOIDCUserInfo(rState, rCode) if err != nil { logger.Log(1, "error when getting user info from callback:", err.Error()) http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) @@ -98,7 +95,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { var jwt, jwtErr = logic.VerifyAuthRequest(authRequest) if jwtErr != nil { - logger.Log(1, "could not parse jwt for user", authRequest.UserName) + logger.Log(1, "could not parse jwt for user", authRequest.UserName, jwtErr.Error()) return } @@ -106,10 +103,12 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) } -func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) { +func getOIDCUserInfo(state string, code string) (u *OAuthUser, e error) { oauth_state_string, isValid := logic.IsStateValid(state) - if !isValid || state != oauth_state_string { - return nil, fmt.Errorf("invalid OAuth state") + logger.Log(3, "using oauth state string:,", oauth_state_string) + logger.Log(3, " state string:,", state) + if (!isValid || state != oauth_state_string) && !isStateCached(state) { + return nil, fmt.Errorf("invalid oauth state") } defer func() { @@ -136,7 +135,7 @@ func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) { return nil, fmt.Errorf("failed to verify raw id_token: \"%s\"", err.Error()) } - u = &OIDCUser{} + u = &OAuthUser{} if err := idToken.Claims(u); err != nil { e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error()) } diff --git a/auth/templates.go b/auth/templates.go new file mode 100644 index 00000000..65f5d124 --- /dev/null +++ b/auth/templates.go @@ -0,0 +1,81 @@ +package auth + +import "html/template" + +type ssoCallbackTemplateConfig struct { + User string + Verb string +} + +var ssoCallbackTemplate = template.Must( + template.New("ssocallback").Parse(` + + + + + + + Netmaker + + + +
+
+ + Netmaker + +
+
+
+

{{.User}} has been successfully {{.Verb}}

+
+

You may now close this window.

+
+
+
+ + `), +) + +var ssoErrCallbackTemplate = template.Must( + template.New("ssocallback").Parse(` + + + + + + + Netmaker + + + +
+
+ + Netmaker + +
+
+
+

{{.User}} unable to join network: {{.Verb}}

+
+

If you feel this is a mistake, please contact your network administrator.

+
+
+
+ + `), +) diff --git a/compose/docker-compose.yml b/compose/docker-compose.yml index a78aed34..45c4a8ea 100644 --- a/compose/docker-compose.yml +++ b/compose/docker-compose.yml @@ -39,6 +39,7 @@ services: VERBOSITY: "1" MANAGE_IPTABLES: "on" PORT_FORWARD_SERVICES: "dns" + METRICS_EXPORTER: "on" ports: - "51821-51830:51821-51830/udp" expose: @@ -111,6 +112,7 @@ services: restart: unless-stopped volumes: - /root/mosquitto.conf:/mosquitto/config/mosquitto.conf + - /root/mosquitto.passwords:/etc/mosquitto.passwords - mosquitto_data:/mosquitto/data - mosquitto_logs:/mosquitto/log - shared_certs:/mosquitto/certs @@ -123,6 +125,66 @@ services: - traefik.tcp.services.mqtts-svc.loadbalancer.server.port=8883 - traefik.tcp.routers.mqtts.service=mqtts-svc - traefik.tcp.routers.mqtts.entrypoints=websecure + prometheus: + container_name: prometheus + image: gravitl/netmaker-prometheus:latest + environment: + NETMAKER_METRICS_TARGET: "netmaker-exporter.NETMAKER_BASE_DOMAIN" + labels: + - traefik.enable=true + - traefik.http.routers.prometheus.entrypoints=websecure + - traefik.http.routers.prometheus.rule=Host(`prometheus.NETMAKER_BASE_DOMAIN`) + - traefik.http.services.prometheus.loadbalancer.server.port=9090 + - traefik.http.routers.prometheus.service=prometheus + restart: always + volumes: + - prometheus_data:/prometheus + depends_on: + - netmaker + ports: + - 9090:9090 + grafana: + container_name: grafana + image: gravitl/netmaker-grafana:latest + labels: + - traefik.enable=true + - traefik.http.routers.grafana.entrypoints=websecure + - traefik.http.routers.grafana.rule=Host(`grafana.NETMAKER_BASE_DOMAIN`) + - traefik.http.services.grafana.loadbalancer.server.port=3000 + - traefik.http.routers.grafana.service=grafana + environment: + PROMETHEUS_HOST: "prometheus.NETMAKER_BASE_DOMAIN" + NETMAKER_METRICS_TARGET: "netmaker-exporter.NETMAKER_BASE_DOMAIN" + ports: + - 3000:3000 + restart: always + links: + - prometheus + depends_on: + - prometheus + - netmaker + netmaker-exporter: + container_name: netmaker-exporter + image: gravitl/netmaker-exporter:latest + labels: + - traefik.enable=true + - traefik.http.routers.netmaker-exporter.entrypoints=websecure + - traefik.http.routers.netmaker-exporter.rule=Host(`netmaker-exporter.NETMAKER_BASE_DOMAIN`) + - traefik.http.services.netmaker-exporter.loadbalancer.server.port=8085 + - traefik.http.routers.netmaker-exporter.service=netmaker-exporter + restart: always + depends_on: + - netmaker + environment: + MQ_HOST: "mq" + MQ_PORT: "443" + MQ_SERVER_PORT: "1884" + PROMETHEUS: "on" + VERBOSITY: "1" + API_PORT: "8085" + PROMETHEUS_HOST: https://prometheus.NETMAKER_BASE_DOMAIN + expose: + - "8085" volumes: traefik_certs: {} shared_certs: {} @@ -130,3 +192,4 @@ volumes: dnsconfig: {} mosquitto_data: {} mosquitto_logs: {} + prometheus_data: {} diff --git a/config/config.go b/config/config.go index 71dd0e44..973c1d5d 100644 --- a/config/config.go +++ b/config/config.go @@ -70,6 +70,10 @@ type ServerConfig struct { MQServerPort string `yaml:"mqserverport"` Server string `yaml:"server"` PublicIPService string `yaml:"publicipservice"` + MetricsExporter string `yaml:"metrics_exporter"` + BasicAuth string `yaml:"basic_auth"` + LicenseValue string `yaml:"license_value"` + NetmakerAccountID string `yaml:"netmaker_account_id"` } // SQLConfig - Generic SQL Config diff --git a/controllers/controller.go b/controllers/controller.go index 7e3ea624..f43897c3 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -25,6 +25,10 @@ var HttpHandlers = []interface{}{ serverHandlers, extClientHandlers, ipHandlers, + metricHandlers, + loggerHandlers, + userGroupsHandlers, + networkUsersHandlers, } // HandleRESTRequests - handles the rest requests diff --git a/controllers/ext_client.go b/controllers/ext_client.go index 767a2ecb..05cdad33 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -12,7 +12,9 @@ import ( "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/mq" "github.com/skip2/go-qrcode" ) @@ -22,10 +24,10 @@ func extClientHandlers(r *mux.Router) { r.HandleFunc("/api/extclients", securityCheck(false, http.HandlerFunc(getAllExtClients))).Methods("GET") r.HandleFunc("/api/extclients/{network}", securityCheck(false, http.HandlerFunc(getNetworkExtClients))).Methods("GET") r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(getExtClient))).Methods("GET") - r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", securityCheck(false, http.HandlerFunc(getExtClientConf))).Methods("GET") - r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(updateExtClient))).Methods("PUT") - r.HandleFunc("/api/extclients/{network}/{clientid}", securityCheck(false, http.HandlerFunc(deleteExtClient))).Methods("DELETE") - r.HandleFunc("/api/extclients/{network}/{nodeid}", securityCheck(false, http.HandlerFunc(createExtClient))).Methods("POST") + r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", netUserSecurityCheck(false, true, http.HandlerFunc(getExtClientConf))).Methods("GET") + r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(updateExtClient))).Methods("PUT") + r.HandleFunc("/api/extclients/{network}/{clientid}", netUserSecurityCheck(false, true, http.HandlerFunc(deleteExtClient))).Methods("DELETE") + r.HandleFunc("/api/extclients/{network}/{nodeid}", netUserSecurityCheck(false, true, checkFreeTierLimits(clients_l, http.HandlerFunc(createExtClient)))).Methods("POST") } func checkIngressExists(nodeID string) bool { @@ -337,6 +339,8 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if err == nil { // check if parent network default ACL is enabled (yes) or not (no) extclient.Enabled = parentNetwork.DefaultACL == "yes" } + // check pro settings + err = logic.CreateExtClient(&extclient) if err != nil { logger.Log(0, r.Header.Get("user"), @@ -344,6 +348,27 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + + var isAdmin bool + if r.Header.Get("ismaster") != "yes" { + userID := r.Header.Get("user") + if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil { + logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access") + logic.DeleteExtClient(networkName, extclient.ClientID) + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + if !isAdmin { + if err = pro.AssociateNetworkUserClient(userID, networkName, extclient.ClientID); err != nil { + logger.Log(0, "failed to associate client", extclient.ClientID, "to user", userID) + } + extclient.OwnerID = userID + if _, err := logic.UpdateExtClient(extclient.ClientID, extclient.Network, extclient.Enabled, &extclient); err != nil { + logger.Log(0, "failed to add owner id", userID, "to client", extclient.ClientID) + } + } + } + logger.Log(0, r.Header.Get("user"), "created new ext client on network", networkName) w.WriteHeader(http.StatusOK) err = mq.PublishExtPeerUpdate(&node) @@ -402,7 +427,31 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + + // == PRO == + networkName := params["network"] + var changedID = newExtClient.ClientID != oldExtClient.ClientID + if r.Header.Get("ismaster") != "yes" { + userID := r.Header.Get("user") + _, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName) + if !doesOwn { + returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) + return + } + } + + if changedID && oldExtClient.OwnerID != "" { + if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, networkName, oldExtClient.ClientID); err != nil { + logger.Log(0, "failed to dissociate client", oldExtClient.ClientID, "from user", oldExtClient.OwnerID) + } + if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, networkName, newExtClient.ClientID); err != nil { + logger.Log(0, "failed to associate client", newExtClient.ClientID, "to user", oldExtClient.OwnerID) + } + } + // == END PRO == + var changedEnabled = newExtClient.Enabled != oldExtClient.Enabled // indicates there was a change in enablement + newclient, err := logic.UpdateExtClient(newExtClient.ClientID, params["network"], newExtClient.Enabled, &oldExtClient) if err != nil { logger.Log(0, r.Header.Get("user"), @@ -459,6 +508,24 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { return } + // == PRO == + if r.Header.Get("ismaster") != "yes" { + userID, clientID, networkName := r.Header.Get("user"), params["clientid"], params["network"] + _, doesOwn := doesUserOwnClient(userID, clientID, networkName) + if !doesOwn { + returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "internal")) + return + } + } + + if extclient.OwnerID != "" { + if err = pro.DissociateNetworkUserClient(extclient.OwnerID, extclient.Network, extclient.ClientID); err != nil { + logger.Log(0, "failed to dissociate client", extclient.ClientID, "from user", extclient.OwnerID) + } + } + + // == END PRO == + err = logic.DeleteExtClient(params["network"], params["clientid"]) if err != nil { logger.Log(0, r.Header.Get("user"), @@ -472,7 +539,65 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { if err != nil { logger.Log(1, "error setting ext peers on "+ingressnode.ID+": "+err.Error()) } + logger.Log(0, r.Header.Get("user"), "Deleted extclient client", params["clientid"], "from network", params["network"]) returnSuccessResponse(w, r, params["clientid"]+" deleted.") } + +func checkProClientAccess(username, clientID string, network *models.Network) (bool, error) { + u, err := logic.GetUser(username) + if err != nil { + return false, err + } + if u.IsAdmin { + return true, nil + } + + netUser, err := pro.GetNetworkUser(network.NetID, promodels.NetworkUserID(u.UserName)) + if err != nil { + return false, err + } + + if netUser.AccessLevel == pro.NET_ADMIN { + return false, nil + } + + if netUser.AccessLevel == pro.NO_ACCESS { + return false, fmt.Errorf("user does not have access") + } + + if !(len(netUser.Clients) < netUser.ClientLimit) { + return false, fmt.Errorf("user can not create more clients") + } + + if netUser.AccessLevel < pro.NO_ACCESS { + netUser.Clients = append(netUser.Clients, clientID) + if err = pro.UpdateNetworkUser(network.NetID, netUser); err != nil { + return false, err + } + } + return false, nil +} + +// checks if net user owns an ext client or is an admin +func doesUserOwnClient(username, clientID, network string) (bool, bool) { + u, err := logic.GetUser(username) + if err != nil { + return false, false + } + if u.IsAdmin { + return true, true + } + + netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(u.UserName)) + if err != nil { + return false, false + } + + if netUser.AccessLevel == pro.NET_ADMIN { + return false, true + } + + return false, logic.StringSliceContains(netUser.Clients, clientID) +} diff --git a/controllers/limits.go b/controllers/limits.go new file mode 100644 index 00000000..ddbff298 --- /dev/null +++ b/controllers/limits.go @@ -0,0 +1,59 @@ +package controller + +import ( + "net/http" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/ee" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" +) + +// limit consts +const ( + node_l = 0 + networks_l = 1 + users_l = 2 + clients_l = 3 +) + +func checkFreeTierLimits(limit_choice int, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var errorResponse = models.ErrorResponse{ + Code: http.StatusUnauthorized, Message: "free tier limits exceeded on networks", + } + + if ee.Limits.FreeTier { // check that free tier limits not exceeded + if limit_choice == networks_l { + currentNetworks, err := logic.GetNetworks() + if (err != nil && !database.IsEmptyRecord(err)) || len(currentNetworks) >= ee.Limits.Networks { + returnErrorResponse(w, r, errorResponse) + return + } + } else if limit_choice == node_l { + nodes, err := logic.GetAllNodes() + if (err != nil && !database.IsEmptyRecord(err)) || len(nodes) >= ee.Limits.Nodes { + errorResponse.Message = "free tier limits exceeded on nodes" + returnErrorResponse(w, r, errorResponse) + return + } + } else if limit_choice == users_l { + users, err := logic.GetUsers() + if (err != nil && !database.IsEmptyRecord(err)) || len(users) >= ee.Limits.Users { + errorResponse.Message = "free tier limits exceeded on users" + returnErrorResponse(w, r, errorResponse) + return + } + } else if limit_choice == clients_l { + clients, err := logic.GetAllExtClients() + if (err != nil && !database.IsEmptyRecord(err)) || len(clients) >= ee.Limits.Clients { + errorResponse.Message = "free tier limits exceeded on external clients" + returnErrorResponse(w, r, errorResponse) + return + } + } + } + + next.ServeHTTP(w, r) + } +} diff --git a/controllers/logger.go b/controllers/logger.go new file mode 100644 index 00000000..316783fc --- /dev/null +++ b/controllers/logger.go @@ -0,0 +1,22 @@ +package controller + +import ( + "fmt" + "net/http" + "time" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" +) + +func loggerHandlers(r *mux.Router) { + r.HandleFunc("/api/logs", securityCheck(true, http.HandlerFunc(getLogs))).Methods("GET") +} + +func getLogs(w http.ResponseWriter, r *http.Request) { + var currentTime = time.Now().Format(logger.TimeFormatDay) + var currentFilePath = fmt.Sprintf("data/netmaker.log.%s", currentTime) + logger.DumpFile(currentFilePath) + w.WriteHeader(http.StatusOK) + w.Write([]byte(logger.Retrieve(currentFilePath))) +} diff --git a/controllers/metrics.go b/controllers/metrics.go new file mode 100644 index 00000000..1c08350f --- /dev/null +++ b/controllers/metrics.go @@ -0,0 +1,102 @@ +package controller + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" +) + +func metricHandlers(r *mux.Router) { + r.HandleFunc("/api/metrics/{network}/{nodeid}", securityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods("GET") + r.HandleFunc("/api/metrics/{network}", securityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods("GET") + r.HandleFunc("/api/metrics", securityCheck(true, http.HandlerFunc(getAllMetrics))).Methods("GET") +} + +// get the metrics of a given node +func getNodeMetrics(w http.ResponseWriter, r *http.Request) { + // set header. + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + nodeID := params["nodeid"] + + logger.Log(1, r.Header.Get("user"), "requested fetching metrics for node", nodeID, "on network", params["network"]) + metrics, err := logic.GetMetrics(nodeID) + if err != nil { + logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of node", nodeID, err.Error()) + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + logger.Log(1, r.Header.Get("user"), "fetched metrics for node", params["nodeid"]) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(metrics) +} + +// get the metrics of all nodes in given network +func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) { + // set header. + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + network := params["network"] + + logger.Log(1, r.Header.Get("user"), "requested fetching network node metrics on network", network) + networkNodes, err := logic.GetNetworkNodes(network) + if err != nil { + logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes in network", network, err.Error()) + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + networkMetrics := models.NetworkMetrics{} + networkMetrics.Nodes = make(models.MetricsMap) + + for i := range networkNodes { + id := networkNodes[i].ID + metrics, err := logic.GetMetrics(id) + if err != nil { + logger.Log(1, r.Header.Get("user"), "failed to append metrics of node", id, "during network metrics fetch", err.Error()) + continue + } + networkMetrics.Nodes[id] = *metrics + } + + logger.Log(1, r.Header.Get("user"), "fetched metrics for network", network) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(networkMetrics) +} + +// get Metrics of all nodes on server, lots of data +func getAllMetrics(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + logger.Log(1, r.Header.Get("user"), "requested fetching all metrics") + + allNodes, err := logic.GetAllNodes() + if err != nil { + logger.Log(1, r.Header.Get("user"), "failed to fetch metrics of all nodes on server", err.Error()) + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + networkMetrics := models.NetworkMetrics{} + networkMetrics.Nodes = make(models.MetricsMap) + + for i := range allNodes { + id := allNodes[i].ID + metrics, err := logic.GetMetrics(id) + if err != nil { + logger.Log(1, r.Header.Get("user"), "failed to append metrics of node", id, "during all nodes metrics fetch", err.Error()) + continue + } + networkMetrics.Nodes[id] = *metrics + } + + logger.Log(1, r.Header.Get("user"), "fetched metrics for all nodes on server") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(networkMetrics) +} diff --git a/controllers/network.go b/controllers/network.go index b94a230c..7c9c71f1 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -25,7 +25,7 @@ const NO_NETWORKS_PRESENT = "THIS_USER_HAS_NONE" func networkHandlers(r *mux.Router) { r.HandleFunc("/api/networks", securityCheck(false, http.HandlerFunc(getNetworks))).Methods("GET") - r.HandleFunc("/api/networks", securityCheck(true, http.HandlerFunc(createNetwork))).Methods("POST") + r.HandleFunc("/api/networks", securityCheck(true, checkFreeTierLimits(networks_l, http.HandlerFunc(createNetwork)))).Methods("POST") r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(getNetwork))).Methods("GET") r.HandleFunc("/api/networks/{networkname}", securityCheck(false, http.HandlerFunc(updateNetwork))).Methods("PUT") r.HandleFunc("/api/networks/{networkname}/nodelimit", securityCheck(true, http.HandlerFunc(updateNetworkNodeLimit))).Methods("PUT") @@ -199,7 +199,7 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { newNetwork.DefaultPostUp = network.DefaultPostUp } - rangeupdate4, rangeupdate6, localrangeupdate, holepunchupdate, err := logic.UpdateNetwork(&network, &newNetwork) + rangeupdate4, rangeupdate6, localrangeupdate, holepunchupdate, groupsDelta, userDelta, err := logic.UpdateNetwork(&network, &newNetwork) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to update network: ", err.Error()) @@ -207,6 +207,24 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) { return } + if len(groupsDelta) > 0 { + for _, g := range groupsDelta { + users, err := logic.GetGroupUsers(g) + if err == nil { + for _, user := range users { + logic.AdjustNetworkUserPermissions(&user, &newNetwork) + } + } + } + } + if len(userDelta) > 0 { + for _, uname := range userDelta { + user, err := logic.GetReturnUser(uname) + if err == nil { + logic.AdjustNetworkUserPermissions(&user, &newNetwork) + } + } + } if rangeupdate4 { err = logic.UpdateNetworkNodeAddresses(network.NetID) if err != nil { @@ -536,6 +554,15 @@ func createAccessKey(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "badrequest")) return } + + // do not allow access key creations view API with user names + if _, err = logic.GetUser(key.Name); err == nil { + logger.Log(0, "access key creation with invalid name attempted by", r.Header.Get("user")) + returnErrorResponse(w, r, formatError(fmt.Errorf("cannot create access key with user name"), "badrequest")) + logic.DeleteKey(key.Name, network.NetID) + return + } + logger.Log(1, r.Header.Get("user"), "created access key", accesskey.Name, "on", netname) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(key) diff --git a/controllers/network_test.go b/controllers/network_test.go index fee91b9f..a85b03e3 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -17,7 +17,7 @@ type NetworkValidationTestCase struct { } func TestCreateNetwork(t *testing.T) { - database.InitializeDatabase() + initialize() deleteAllNetworks() var network models.Network @@ -30,7 +30,7 @@ func TestCreateNetwork(t *testing.T) { assert.Nil(t, err) } func TestGetNetwork(t *testing.T) { - database.InitializeDatabase() + initialize() createNet() t.Run("GetExistingNetwork", func(t *testing.T) { @@ -46,7 +46,7 @@ func TestGetNetwork(t *testing.T) { } func TestDeleteNetwork(t *testing.T) { - database.InitializeDatabase() + initialize() createNet() //create nodes t.Run("NetworkwithNodes", func(t *testing.T) { @@ -62,7 +62,7 @@ func TestDeleteNetwork(t *testing.T) { } func TestCreateKey(t *testing.T) { - database.InitializeDatabase() + initialize() createNet() keys, _ := logic.GetKeys("skynet") for _, key := range keys { @@ -74,7 +74,7 @@ func TestCreateKey(t *testing.T) { t.Run("NameTooLong", func(t *testing.T) { network, err := logic.GetNetwork("skynet") assert.Nil(t, err) - accesskey.Name = "Thisisareallylongkeynamethatwillfail" + accesskey.Name = "ThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfailThisisareallylongkeynamethatwillfail" _, err = logic.CreateAccessKey(accesskey, network) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag") @@ -134,7 +134,7 @@ func TestCreateKey(t *testing.T) { } func TestGetKeys(t *testing.T) { - database.InitializeDatabase() + initialize() deleteAllNetworks() createNet() network, err := logic.GetNetwork("skynet") @@ -157,7 +157,7 @@ func TestGetKeys(t *testing.T) { }) } func TestDeleteKey(t *testing.T) { - database.InitializeDatabase() + initialize() createNet() network, err := logic.GetNetwork("skynet") assert.Nil(t, err) @@ -179,7 +179,7 @@ func TestDeleteKey(t *testing.T) { func TestSecurityCheck(t *testing.T) { //these seem to work but not sure it the tests are really testing the functionality - database.InitializeDatabase() + initialize() os.Setenv("MASTER_KEY", "secretkey") t.Run("NoNetwork", func(t *testing.T) { networks, username, err := SecurityCheck(false, "", "Bearer secretkey") @@ -210,7 +210,7 @@ func TestValidateNetwork(t *testing.T) { //t.Skip() //This functions is not called by anyone //it panics as validation function 'display_name_valid' is not defined - database.InitializeDatabase() + initialize() //yes := true //no := false //deleteNet(t) @@ -295,7 +295,7 @@ func TestValidateNetwork(t *testing.T) { func TestIpv6Network(t *testing.T) { //these seem to work but not sure it the tests are really testing the functionality - database.InitializeDatabase() + initialize() os.Setenv("MASTER_KEY", "secretkey") deleteAllNetworks() createNet() @@ -321,6 +321,21 @@ func deleteAllNetworks() { } } +func initialize() { + database.InitializeDatabase() + createAdminUser() +} + +func createAdminUser() { + logic.CreateAdmin(models.User{ + UserName: "admin", + Password: "password", + IsAdmin: true, + Networks: []string{}, + Groups: []string{}, + }) +} + func createNet() { var network models.Network network.NetID = "skynet" diff --git a/controllers/networkusers.go b/controllers/networkusers.go new file mode 100644 index 00000000..1e6c37d4 --- /dev/null +++ b/controllers/networkusers.go @@ -0,0 +1,359 @@ +package controller + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" +) + +func networkUsersHandlers(r *mux.Router) { + r.HandleFunc("/api/networkusers", securityCheck(true, http.HandlerFunc(getAllNetworkUsers))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(getNetworkUsers))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(getNetworkUser))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(createNetworkUser))).Methods("POST") + r.HandleFunc("/api/networkusers/{network}", securityCheck(true, http.HandlerFunc(updateNetworkUser))).Methods("PUT") + r.HandleFunc("/api/networkusers/data/{networkuser}/me", netUserSecurityCheck(false, false, http.HandlerFunc(getNetworkUserData))).Methods("GET") + r.HandleFunc("/api/networkusers/{network}/{networkuser}", securityCheck(true, http.HandlerFunc(deleteNetworkUser))).Methods("DELETE") +} + +// == RETURN TYPES == + +// NetworkName - represents a network name/ID +type NetworkName string + +// NetworkUserDataMap - map of all data per network for a user +type NetworkUserDataMap map[NetworkName]NetworkUserData + +// NetworkUserData - data struct for network users +type NetworkUserData struct { + Nodes []models.Node `json:"nodes" bson:"nodes" yaml:"nodes"` + Clients []models.ExtClient `json:"clients" bson:"clients" yaml:"clients"` + Vpn []models.Node `json:"vpns" bson:"vpns" yaml:"vpns"` + Networks []models.Network `json:"networks" bson:"networks" yaml:"networks"` + User promodels.NetworkUser `json:"user" bson:"user" yaml:"user"` +} + +// == END RETURN TYPES == + +// returns a map of a network user's data across all networks +func getNetworkUserData(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + networkUserName := params["networkuser"] + logger.Log(1, r.Header.Get("user"), "requested fetching network user data for user", networkUserName) + + networks, err := logic.GetNetworks() + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + if networkUserName == "" { + returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) + return + } + + u, err := logic.GetUser(networkUserName) + if err != nil { + returnErrorResponse(w, r, formatError(errors.New("could not find user"), "badrequest")) + return + } + + // initialize the return data of network users + returnData := make(NetworkUserDataMap) + + // go through each network and get that user's data + // if user has no access, give no data + // if user is a net admin, give all nodes + // if user has node access, give user's nodes if any + // if user has client access, git user's clients if any + for i := range networks { + + netID := networks[i].NetID + newData := NetworkUserData{ + Nodes: []models.Node{}, + Clients: []models.ExtClient{}, + Vpn: []models.Node{}, + Networks: []models.Network{}, + } + netUser, err := pro.GetNetworkUser(netID, promodels.NetworkUserID(networkUserName)) + // check if user has access + if err == nil && netUser.AccessLevel != pro.NO_ACCESS { + newData.User = promodels.NetworkUser{ + AccessLevel: netUser.AccessLevel, + ClientLimit: netUser.ClientLimit, + NodeLimit: netUser.NodeLimit, + Nodes: netUser.Nodes, + Clients: netUser.Clients, + } + // check network level permissions + if doesNetworkAllow := pro.IsUserAllowed(&networks[i], networkUserName, u.Groups); doesNetworkAllow { + netNodes, err := logic.GetNetworkNodes(netID) + if err != nil { + logger.Log(0, "failed to retrieve nodes on network", netID, "for user", string(netUser.ID)) + } else { + if netUser.AccessLevel <= pro.NODE_ACCESS { // handle nodes + // if access level is NODE_ACCESS, filter nodes + if netUser.AccessLevel == pro.NODE_ACCESS { + for i := range netNodes { + if logic.StringSliceContains(netUser.Nodes, netNodes[i].ID) { + newData.Nodes = append(newData.Nodes, netNodes[i]) + } + } + } else { // net admin so, get all nodes and ext clients on network... + newData.Nodes = netNodes + for i := range netNodes { + if netNodes[i].IsIngressGateway == "yes" { + newData.Vpn = append(newData.Vpn, netNodes[i]) + if clients, err := logic.GetExtClientsByID(netNodes[i].ID, netID); err == nil { + newData.Clients = append(newData.Clients, clients...) + } + } + } + newData.Networks = append(newData.Networks, networks[i]) + } + } + if netUser.AccessLevel <= pro.CLIENT_ACCESS && netUser.AccessLevel != pro.NET_ADMIN { + for _, c := range netUser.Clients { + if client, err := logic.GetExtClient(c, netID); err == nil { + newData.Clients = append(newData.Clients, client) + } + } + for i := range netNodes { + if netNodes[i].IsIngressGateway == "yes" { + newData.Vpn = append(newData.Vpn, netNodes[i]) + } + } + } + } + } + returnData[NetworkName(netID)] = newData + } + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(returnData) +} + +// returns a map of all network users mapped to each network +func getAllNetworkUsers(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + logger.Log(1, r.Header.Get("user"), "requested fetching all network users") + type allNetworkUsers = map[string][]promodels.NetworkUser + + networks, err := logic.GetNetworks() + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + var allNetUsers = make(allNetworkUsers, len(networks)) + + for i := range networks { + netusers, err := pro.GetNetworkUsers(networks[i].NetID) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + for _, v := range netusers { + allNetUsers[networks[i].NetID] = append(allNetUsers[networks[i].NetID], v) + } + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(allNetUsers) +} + +func getNetworkUsers(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + netname := params["network"] + logger.Log(1, r.Header.Get("user"), "requested fetching network users for network", netname) + + _, err := logic.GetNetwork(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + netusers, err := pro.GetNetworkUsers(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(netusers) +} + +func getNetworkUser(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + netname := params["network"] + logger.Log(1, r.Header.Get("user"), "requested fetching network user", params["networkuser"], "on network", netname) + + _, err := logic.GetNetwork(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + netuserToGet := params["networkuser"] + if netuserToGet == "" { + returnErrorResponse(w, r, formatError(errors.New("netuserToGet"), "badrequest")) + return + } + + netuser, err := pro.GetNetworkUser(netname, promodels.NetworkUserID(netuserToGet)) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(netuser) +} + +func createNetworkUser(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var params = mux.Vars(r) + netname := params["network"] + logger.Log(1, r.Header.Get("user"), "requested creating a network user on network", netname) + + network, err := logic.GetNetwork(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + var networkuser promodels.NetworkUser + + // we decode our body request params + err = json.NewDecoder(r.Body).Decode(&networkuser) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + err = pro.CreateNetworkUser(&network, &networkuser) + if err != nil { + returnErrorResponse(w, r, formatError(err, "badrequest")) + return + } + + w.WriteHeader(http.StatusOK) +} + +func updateNetworkUser(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var params = mux.Vars(r) + netname := params["network"] + logger.Log(1, r.Header.Get("user"), "requested updating a network user on network", netname) + + network, err := logic.GetNetwork(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + var networkuser promodels.NetworkUser + + // we decode our body request params + err = json.NewDecoder(r.Body).Decode(&networkuser) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + if networkuser.ID == "" || !pro.DoesNetworkUserExist(netname, networkuser.ID) { + returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) + return + } + if networkuser.AccessLevel < pro.NET_ADMIN || networkuser.AccessLevel > pro.NO_ACCESS { + returnErrorResponse(w, r, formatError(errors.New("invalid user access level provided"), "badrequest")) + return + } + + if networkuser.ClientLimit < 0 || networkuser.NodeLimit < 0 { + returnErrorResponse(w, r, formatError(errors.New("negative user limit provided"), "badrequest")) + return + } + + u, err := logic.GetUser(string(networkuser.ID)) + if err != nil { + returnErrorResponse(w, r, formatError(errors.New("invalid user "+string(networkuser.ID)), "badrequest")) + return + } + + if !pro.IsUserAllowed(&network, u.UserName, u.Groups) { + returnErrorResponse(w, r, formatError(errors.New("user must be in allowed groups or users"), "badrequest")) + return + } + + if networkuser.AccessLevel == pro.NET_ADMIN { + currentUser, err := logic.GetUser(string(networkuser.ID)) + if err != nil { + returnErrorResponse(w, r, formatError(errors.New("user model not found for "+string(networkuser.ID)), "badrequest")) + return + } + + if !logic.StringSliceContains(currentUser.Networks, netname) { + // append network name to user model to conform to old model + if err = logic.UpdateUserNetworks( + append(currentUser.Networks, netname), + currentUser.Groups, + currentUser.IsAdmin, + &models.ReturnUser{ + Groups: currentUser.Groups, + IsAdmin: currentUser.IsAdmin, + Networks: currentUser.Networks, + UserName: currentUser.UserName, + }, + ); err != nil { + returnErrorResponse(w, r, formatError(errors.New("user model failed net admin update "+string(networkuser.ID)+" (are they an admin?"), "badrequest")) + return + } + } + } + + err = pro.UpdateNetworkUser(netname, &networkuser) + if err != nil { + returnErrorResponse(w, r, formatError(err, "badrequest")) + return + } + + w.WriteHeader(http.StatusOK) +} + +func deleteNetworkUser(w http.ResponseWriter, r *http.Request) { + + var params = mux.Vars(r) + netname := params["network"] + + logger.Log(1, r.Header.Get("user"), "requested deleting network user", params["networkuser"], "on network", netname) + + _, err := logic.GetNetwork(netname) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + netuserToDelete := params["networkuser"] + if netuserToDelete == "" { + returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + return + } + + if err := pro.DeleteNetworkUser(netname, netuserToDelete); err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/controllers/node.go b/controllers/node.go index 146f5681..14075dbe 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -11,7 +11,9 @@ import ( "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" @@ -31,7 +33,7 @@ func nodeHandlers(r *mux.Router) { r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", securityCheck(false, http.HandlerFunc(createIngressGateway))).Methods("POST") r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", securityCheck(false, http.HandlerFunc(deleteIngressGateway))).Methods("DELETE") r.HandleFunc("/api/nodes/{network}/{nodeid}/approve", authorize(false, true, "user", http.HandlerFunc(uncordonNode))).Methods("POST") - r.HandleFunc("/api/nodes/{network}", nodeauth(http.HandlerFunc(createNode))).Methods("POST") + r.HandleFunc("/api/nodes/{network}", nodeauth(checkFreeTierLimits(node_l, http.HandlerFunc(createNode)))).Methods("POST") r.HandleFunc("/api/nodes/adm/{network}/lastmodified", authorize(false, true, "network", http.HandlerFunc(getLastModified))).Methods("GET") r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods("POST") } @@ -237,6 +239,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha returnErrorResponse(w, r, errorResponse) return } + isnetadmin := isadmin if errN == nil && isadmin { nodeID = "mastermac" @@ -244,7 +247,7 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha r.Header.Set("ismasterkey", "yes") } if !isadmin && params["network"] != "" { - if logic.StringSliceContains(networks, params["network"]) { + if logic.StringSliceContains(networks, params["network"]) && pro.IsUserNetAdmin(params["network"], username) { isnetadmin = true } } @@ -435,6 +438,7 @@ func getNode(w http.ResponseWriter, r *http.Request) { Node: node, Peers: peerUpdate.Peers, ServerConfig: servercfg.GetServerInfo(), + PeerIDs: peerUpdate.PeerIDs, } logger.Log(2, r.Header.Get("user"), "fetched node", params["nodeid"]) @@ -537,7 +541,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } - validKey := logic.IsKeyValid(networkName, node.AccessKey) + keyName, validKey := logic.IsKeyValid(networkName, node.AccessKey) if !validKey { // Check to see if network will allow manual sign up // may want to switch this up with the valid key check and avoid a DB call that way. @@ -554,6 +558,14 @@ func createNode(w http.ResponseWriter, r *http.Request) { return } } + user, err := pro.GetNetworkUser(networkName, promodels.NetworkUserID(keyName)) + if err == nil { + if user.ID != "" { + logger.Log(1, "associating new node with user", keyName) + node.OwnerID = string(user.ID) + } + } + key, keyErr := logic.RetrievePublicTrafficKey() if keyErr != nil { logger.Log(0, "error retrieving key: ", keyErr.Error()) @@ -584,6 +596,24 @@ func createNode(w http.ResponseWriter, r *http.Request) { return } + // check if key belongs to a user + // if so add to their netuser data + // if it fails remove the node and fail request + if user != nil { + var updatedUserNode bool + user.Nodes = append(user.Nodes, node.ID) // add new node to user + if err = pro.UpdateNetworkUser(networkName, user); err == nil { + logger.Log(1, "added node", node.ID, node.Name, "to user", string(user.ID)) + updatedUserNode = true + } + if !updatedUserNode { // user was found but not updated, so delete node + logger.Log(0, "failed to add node to user", keyName) + logic.DeleteNodeByID(&node, true) + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + } + peerUpdate, err := logic.GetPeerUpdate(&node) if err != nil && !database.IsEmptyRecord(err) { logger.Log(0, r.Header.Get("user"), @@ -596,6 +626,7 @@ func createNode(w http.ResponseWriter, r *http.Request) { Node: node, Peers: peerUpdate.Peers, ServerConfig: servercfg.GetServerInfo(), + PeerIDs: peerUpdate.PeerIDs, } logger.Log(1, r.Header.Get("user"), "created new node", node.Name, "on network", node.Network) @@ -911,6 +942,13 @@ func deleteNode(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "badrequest")) return } + if r.Header.Get("ismaster") != "yes" { + username := r.Header.Get("user") + if username != "" && !doesUserOwnNode(username, params["network"], nodeid) { + returnErrorResponse(w, r, formatError(fmt.Errorf("user not permitted"), "badrequest")) + return + } + } //send update to node to be deleted before deleting on server otherwise message cannot be sent node.Action = models.NODE_DELETE @@ -919,6 +957,7 @@ func deleteNode(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + returnSuccessResponse(w, r, nodeid+" deleted.") logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"]) @@ -1006,3 +1045,24 @@ func updateRelay(oldnode, newnode *models.Node) { } logic.UpdateNode(relay, newrelay) } + +func doesUserOwnNode(username, network, nodeID string) bool { + u, err := logic.GetUser(username) + if err != nil { + return false + } + if u.IsAdmin { + return true + } + + netUser, err := pro.GetNetworkUser(network, promodels.NetworkUserID(u.UserName)) + if err != nil { + return false + } + + if netUser.AccessLevel == pro.NET_ADMIN { + return true + } + + return logic.StringSliceContains(netUser.Nodes, nodeID) +} diff --git a/controllers/security.go b/controllers/security.go index 5c447dda..f793da5d 100644 --- a/controllers/security.go +++ b/controllers/security.go @@ -9,7 +9,9 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/servercfg" ) @@ -58,6 +60,75 @@ func securityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc { } } +func netUserSecurityCheck(isNodes, isClients bool, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var errorResponse = models.ErrorResponse{ + Code: http.StatusUnauthorized, Message: "unauthorized", + } + r.Header.Set("ismaster", "no") + + var params = mux.Vars(r) + var netUserName = params["networkuser"] + var network = params["network"] + + bearerToken := r.Header.Get("Authorization") + + var tokenSplit = strings.Split(bearerToken, " ") + var authToken = "" + + if len(tokenSplit) < 2 { + returnErrorResponse(w, r, errorResponse) + return + } else { + authToken = tokenSplit[1] + } + + isMasterAuthenticated := authenticateMaster(authToken) + if isMasterAuthenticated { + r.Header.Set("user", "master token user") + r.Header.Set("ismaster", "yes") + next.ServeHTTP(w, r) + return + } + + userName, _, isadmin, err := logic.VerifyUserToken(authToken) + if err != nil { + returnErrorResponse(w, r, errorResponse) + return + } + r.Header.Set("user", userName) + + if isadmin { + next.ServeHTTP(w, r) + return + } + + if isNodes || isClients { + necessaryAccess := pro.NET_ADMIN + if isClients { + necessaryAccess = pro.CLIENT_ACCESS + } + if isNodes { + necessaryAccess = pro.NODE_ACCESS + } + u, err := pro.GetNetworkUser(network, promodels.NetworkUserID(userName)) + if err != nil { + returnErrorResponse(w, r, errorResponse) + return + } + if u.AccessLevel > necessaryAccess { + returnErrorResponse(w, r, errorResponse) + return + } + } else if netUserName != userName { + returnErrorResponse(w, r, errorResponse) + return + } + + next.ServeHTTP(w, r) + } +} + // SecurityCheck - checks token stuff func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, string, error) { var tokenSplit = strings.Split(token, " ") @@ -88,6 +159,9 @@ func SecurityCheck(reqAdmin bool, netname string, token string) ([]string, strin if len(netname) > 0 && (!authenticateNetworkUser(netname, userNetworks) || len(userNetworks) == 0) { return nil, username, unauthorized_err } + if !pro.IsUserNetAdmin(netname, username) { + return nil, "", unauthorized_err + } return userNetworks, username, nil } diff --git a/controllers/user.go b/controllers/user.go index eddc7cef..9c6dbb57 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -7,11 +7,17 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/gorilla/websocket" "github.com/gravitl/netmaker/auth" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" +) + +var ( + upgrader = websocket.Upgrader{} ) func userHandlers(r *mux.Router) { @@ -22,12 +28,14 @@ func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(updateUser)))).Methods("PUT") r.HandleFunc("/api/users/networks/{username}", securityCheck(true, http.HandlerFunc(updateUserNetworks))).Methods("PUT") r.HandleFunc("/api/users/{username}/adm", securityCheck(true, http.HandlerFunc(updateUserAdm))).Methods("PUT") - r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(createUser))).Methods("POST") + r.HandleFunc("/api/users/{username}", securityCheck(true, checkFreeTierLimits(users_l, http.HandlerFunc(createUser)))).Methods("POST") r.HandleFunc("/api/users/{username}", securityCheck(true, http.HandlerFunc(deleteUser))).Methods("DELETE") r.HandleFunc("/api/users/{username}", securityCheck(false, continueIfUserMatch(http.HandlerFunc(getUser)))).Methods("GET") r.HandleFunc("/api/users", securityCheck(true, http.HandlerFunc(getUsers))).Methods("GET") r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") + r.HandleFunc("/api/oauth/node-handler", socketHandler) + r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterNodeSSO).Methods("GET") } // swagger:route POST /api/users/adm/authenticate nodes authenticateUser @@ -50,6 +58,11 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { Code: http.StatusInternalServerError, Message: "W1R3: It's not you it's me.", } + if !servercfg.IsBasicAuthEnabled() { + returnErrorResponse(response, request, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) + return + } + decoder := json.NewDecoder(request.Body) decoderErr := decoder.Decode(&authRequest) defer request.Body.Close() @@ -216,14 +229,20 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "badrequest")) return } - admin, err = logic.CreateAdmin(admin) + if !servercfg.IsBasicAuthEnabled() { + returnErrorResponse(w, r, formatError(fmt.Errorf("basic auth is disabled"), "badrequest")) + return + } + + admin, err = logic.CreateAdmin(admin) if err != nil { logger.Log(0, admin.UserName, "failed to create admin: ", err.Error()) returnErrorResponse(w, r, formatError(err, "badrequest")) return } + logger.Log(1, admin.UserName, "was made a new admin") json.NewEncoder(w).Encode(admin) } @@ -250,6 +269,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "badrequest")) return } + user, err = logic.CreateUser(user) if err != nil { logger.Log(0, user.UserName, "error creating new user: ", @@ -294,7 +314,13 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "badrequest")) return } - err = logic.UpdateUserNetworks(userchange.Networks, userchange.IsAdmin, &user) + err = logic.UpdateUserNetworks(userchange.Networks, userchange.Groups, userchange.IsAdmin, &models.ReturnUser{ + Groups: user.Groups, + IsAdmin: user.IsAdmin, + Networks: user.Networks, + UserName: user.UserName, + }) + if err != nil { logger.Log(0, username, "failed to update user networks: ", err.Error()) @@ -444,3 +470,15 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { logger.Log(1, username, "was deleted") json.NewEncoder(w).Encode(params["username"] + " deleted.") } + +// Called when vpn client dials in to start the auth flow and first stage is to get register URL itself +func socketHandler(w http.ResponseWriter, r *http.Request) { + // Upgrade our raw HTTP connection to a websocket based one + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Log(0, "error during connection upgrade for node SSO sign-in:", err.Error()) + return + } + // Start handling the session + go auth.SessionHandler(conn) +} diff --git a/controllers/user_test.go b/controllers/user_test.go index 7d47dad7..71a919b1 100644 --- a/controllers/user_test.go +++ b/controllers/user_test.go @@ -31,7 +31,7 @@ func TestHasAdmin(t *testing.T) { assert.False(t, found) }) t.Run("No admin user", func(t *testing.T) { - var user = models.User{"noadmin", "password", nil, false} + var user = models.User{"noadmin", "password", nil, false, nil} _, err := logic.CreateUser(user) assert.Nil(t, err) found, err := logic.HasAdmin() @@ -39,7 +39,7 @@ func TestHasAdmin(t *testing.T) { assert.False(t, found) }) t.Run("admin user", func(t *testing.T) { - var user = models.User{"admin", "password", nil, true} + var user = models.User{"admin", "password", nil, true, nil} _, err := logic.CreateUser(user) assert.Nil(t, err) found, err := logic.HasAdmin() @@ -47,7 +47,7 @@ func TestHasAdmin(t *testing.T) { assert.True(t, found) }) t.Run("multiple admins", func(t *testing.T) { - var user = models.User{"admin1", "password", nil, true} + var user = models.User{"admin1", "password", nil, true, nil} _, err := logic.CreateUser(user) assert.Nil(t, err) found, err := logic.HasAdmin() @@ -59,7 +59,7 @@ func TestHasAdmin(t *testing.T) { func TestCreateUser(t *testing.T) { database.InitializeDatabase() deleteAllUsers() - user := models.User{"admin", "password", nil, true} + user := models.User{"admin", "password", nil, true, nil} t.Run("NoUser", func(t *testing.T) { admin, err := logic.CreateUser(user) assert.Nil(t, err) @@ -101,7 +101,7 @@ func TestDeleteUser(t *testing.T) { assert.False(t, deleted) }) t.Run("Existing User", func(t *testing.T) { - user := models.User{"admin", "password", nil, true} + user := models.User{"admin", "password", nil, true, nil} logic.CreateUser(user) deleted, err := logic.DeleteUser("admin") assert.Nil(t, err) @@ -166,7 +166,7 @@ func TestGetUser(t *testing.T) { assert.Equal(t, "", admin.UserName) }) t.Run("UserExisits", func(t *testing.T) { - user := models.User{"admin", "password", nil, true} + user := models.User{"admin", "password", nil, true, nil} logic.CreateUser(user) admin, err := logic.GetUser("admin") assert.Nil(t, err) @@ -183,7 +183,7 @@ func TestGetUserInternal(t *testing.T) { assert.Equal(t, "", admin.UserName) }) t.Run("UserExisits", func(t *testing.T) { - user := models.User{"admin", "password", nil, true} + user := models.User{"admin", "password", nil, true, nil} logic.CreateUser(user) admin, err := GetUserInternal("admin") assert.Nil(t, err) @@ -200,14 +200,14 @@ func TestGetUsers(t *testing.T) { assert.Equal(t, []models.ReturnUser(nil), admin) }) t.Run("UserExisits", func(t *testing.T) { - user := models.User{"admin", "password", nil, true} + user := models.User{"admin", "password", nil, true, nil} logic.CreateUser(user) admins, err := logic.GetUsers() assert.Nil(t, err) assert.Equal(t, user.UserName, admins[0].UserName) }) t.Run("MulipleUsers", func(t *testing.T) { - user := models.User{"user", "password", nil, true} + user := models.User{"user", "password", nil, true, nil} logic.CreateUser(user) admins, err := logic.GetUsers() assert.Nil(t, err) @@ -225,8 +225,8 @@ func TestGetUsers(t *testing.T) { func TestUpdateUser(t *testing.T) { database.InitializeDatabase() deleteAllUsers() - user := models.User{"admin", "password", nil, true} - newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true} + user := models.User{"admin", "password", nil, true, nil} + newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true, []string{}} t.Run("NonExistantUser", func(t *testing.T) { admin, err := logic.UpdateUser(newuser, user) assert.EqualError(t, err, "could not find any records") @@ -288,10 +288,10 @@ func TestVerifyAuthRequest(t *testing.T) { authRequest.Password = "password" jwt, err := logic.VerifyAuthRequest(authRequest) assert.Equal(t, "", jwt) - assert.EqualError(t, err, "incorrect credentials") + assert.EqualError(t, err, "error retrieving user from db: could not find any records") }) t.Run("Non-Admin", func(t *testing.T) { - user := models.User{"nonadmin", "somepass", nil, false} + user := models.User{"nonadmin", "somepass", nil, false, []string{}} logic.CreateUser(user) authRequest := models.UserAuthParams{"nonadmin", "somepass"} jwt, err := logic.VerifyAuthRequest(authRequest) @@ -299,7 +299,7 @@ func TestVerifyAuthRequest(t *testing.T) { assert.Nil(t, err) }) t.Run("WrongPassword", func(t *testing.T) { - user := models.User{"admin", "password", nil, false} + user := models.User{"admin", "password", nil, false, []string{}} logic.CreateUser(user) authRequest := models.UserAuthParams{"admin", "badpass"} jwt, err := logic.VerifyAuthRequest(authRequest) diff --git a/controllers/usergroups.go b/controllers/usergroups.go new file mode 100644 index 00000000..a73dc1f8 --- /dev/null +++ b/controllers/usergroups.go @@ -0,0 +1,71 @@ +package controller + +import ( + "encoding/json" + "errors" + "github.com/gravitl/netmaker/logger" + "net/http" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logic/pro" + "github.com/gravitl/netmaker/models/promodels" +) + +func userGroupsHandlers(r *mux.Router) { + r.HandleFunc("/api/usergroups", securityCheck(true, http.HandlerFunc(getUserGroups))).Methods("GET") + r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(createUserGroup))).Methods("POST") + r.HandleFunc("/api/usergroups/{usergroup}", securityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods("DELETE") +} + +func getUserGroups(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + logger.Log(1, r.Header.Get("user"), "requested fetching user groups") + + userGroups, err := pro.GetUserGroups() + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + // Returns all the groups in JSON format + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(userGroups) +} + +func createUserGroup(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var params = mux.Vars(r) + newGroup := params["usergroup"] + + logger.Log(1, r.Header.Get("user"), "requested creating user group", newGroup) + + if newGroup == "" { + returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + return + } + + err := pro.InsertUserGroup(promodels.UserGroupName(newGroup)) + if err != nil { + returnErrorResponse(w, r, formatError(err, "badrequest")) + return + } + + w.WriteHeader(http.StatusOK) +} + +func deleteUserGroup(w http.ResponseWriter, r *http.Request) { + var params = mux.Vars(r) + groupToDelete := params["usergroup"] + logger.Log(1, r.Header.Get("user"), "requested deleting user group", groupToDelete) + + if groupToDelete == "" { + returnErrorResponse(w, r, formatError(errors.New("no group name provided"), "badrequest")) + return + } + + if err := pro.DeleteUserGroup(promodels.UserGroupName(groupToDelete)); err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/database/database.go b/database/database.go index 2b3b0d4a..dea641be 100644 --- a/database/database.go +++ b/database/database.go @@ -59,6 +59,18 @@ const NODE_ACLS_TABLE_NAME = "nodeacls" // SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins const SSO_STATE_CACHE = "ssostatecache" +// METRICS_TABLE_NAME - stores network metrics +const METRICS_TABLE_NAME = "metrics" + +// NETWORK_USER_TABLE_NAME - network user table tracks stats for a network user per network +const NETWORK_USER_TABLE_NAME = "networkusers" + +// USER_GROUPS_TABLE_NAME - table for storing usergroups +const USER_GROUPS_TABLE_NAME = "usergroups" + +// CACHE_TABLE_NAME - caching table +const CACHE_TABLE_NAME = "cache" + // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -139,6 +151,10 @@ func createTables() { createTable(GENERATED_TABLE_NAME) createTable(NODE_ACLS_TABLE_NAME) createTable(SSO_STATE_CACHE) + createTable(METRICS_TABLE_NAME) + createTable(NETWORK_USER_TABLE_NAME) + createTable(USER_GROUPS_TABLE_NAME) + createTable(CACHE_TABLE_NAME) } func createTable(tableName string) error { diff --git a/docker/mosquitto.conf b/docker/mosquitto.conf index 8d3ab239..6ee92ddc 100644 --- a/docker/mosquitto.conf +++ b/docker/mosquitto.conf @@ -10,3 +10,7 @@ keyfile /mosquitto/certs/server.key listener 1883 allow_anonymous true + +listener 1884 +allow_anonymous false +password_file /etc/mosquitto.passwords diff --git a/docker/mosquitto.passwords b/docker/mosquitto.passwords new file mode 100644 index 00000000..d6966cf3 --- /dev/null +++ b/docker/mosquitto.passwords @@ -0,0 +1 @@ +netmaker-exporter:$7$101$9kcXwXP+nUMh06gm$MND2YjtRSvcZTXjMn7xYKoqUFQxG6NOgqWmXIcxxxZksM9cA8732URQWOsPHqpGEvVF9mSVagM1MBEMIKwZm2A== diff --git a/ee/LICENSE b/ee/LICENSE new file mode 100644 index 00000000..4de7fe6c --- /dev/null +++ b/ee/LICENSE @@ -0,0 +1,10 @@ +The Netmaker Enterprise license (the “Enterprise License”) +Copyright (c) 2022 Netmaker, Inc. + +With regard to the Netmaker Software: + +This software and associated documentation files (the "Software") may only be used in production, if you (and any entity that you represent) have agreed to, and are in compliance with, the Netmaker Subscription Terms of Service, available at https://netmaker.io/terms (the “Enterprise Terms”), or other agreement governing the use of the Software, as agreed by you and Netmaker, and otherwise have a valid Netmaker Enterprise license for the correct number of users, networks, nodes, servers, and external clients. Subject to the foregoing sentence, you are free to modify this Software and publish patches to the Software. You agree that Netmaker and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications and/or patches, and all such modifications and/or patches may only be used, copied, modified, displayed, distributed, or otherwise exploited with a valid Netmaker Enterprise license for the correct number of users, networks, nodes, servers, and external clients as allocated by the license. Notwithstanding the foregoing, you may copy and modify the Software for development and testing purposes, without requiring a subscription. You agree that Netmaker and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications. You are not granted any other rights beyond what is expressly stated herein. Subject to the foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For all third party components incorporated into the Netmaker Software, those components are licensed under the original license provided by the owner of the applicable component. diff --git a/ee/license.go b/ee/license.go new file mode 100644 index 00000000..f34e50f8 --- /dev/null +++ b/ee/license.go @@ -0,0 +1,216 @@ +package ee + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "math" + "net/http" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" + "github.com/gravitl/netmaker/netclient/ncutils" + "github.com/gravitl/netmaker/servercfg" +) + +// AddLicenseHooks - adds the validation and cache clear hooks +func AddLicenseHooks() { + logic.AddHook(ValidateLicense) + logic.AddHook(ClearLicenseCache) +} + +// ValidateLicense - the initial license check for netmaker server +// checks if a license is valid + limits are not exceeded +// if license is free_tier and limits exceeds, then server should terminate +// if license is not valid, server should terminate +func ValidateLicense() error { + licenseKeyValue := servercfg.GetLicenseKey() + netmakerAccountID := servercfg.GetNetmakerAccountID() + logger.Log(0, "proceeding with Netmaker license validation...") + if len(licenseKeyValue) == 0 || len(netmakerAccountID) == 0 { + logger.FatalLog(errValidation.Error()) + } + + apiPublicKey, err := getLicensePublicKey(licenseKeyValue) + if err != nil { + logger.FatalLog(errValidation.Error()) + } + + tempPubKey, tempPrivKey, err := pro.FetchApiServerKeys() + if err != nil { + logger.FatalLog(errValidation.Error()) + } + + licenseSecret := LicenseSecret{ + UserID: netmakerAccountID, + Limits: getCurrentServerLimit(), + } + + secretData, err := json.Marshal(&licenseSecret) + if err != nil { + logger.FatalLog(errValidation.Error()) + } + + encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey) + if err != nil { + logger.FatalLog(errValidation.Error()) + } + + validationResponse, err := validateLicenseKey(encryptedData, tempPubKey) + if err != nil || len(validationResponse) == 0 { + logger.FatalLog(errValidation.Error()) + } + + var licenseResponse ValidatedLicense + if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil { + logger.FatalLog(errValidation.Error()) + } + + respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey) + if err != nil { + logger.FatalLog(errValidation.Error()) + } + + license := LicenseKey{} + if err = json.Unmarshal(respData, &license); err != nil { + logger.FatalLog(errValidation.Error()) + } + + Limits.Networks = math.MaxInt + Limits.FreeTier = license.FreeTier == "yes" + Limits.Clients = license.LimitClients + Limits.Nodes = license.LimitNodes + Limits.Servers = license.LimitServers + Limits.Users = license.LimitUsers + if Limits.FreeTier { + Limits.Networks = 3 + } + + logger.Log(0, "License validation succeeded!") + return nil +} + +func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { + decodedPubKey := base64decode(licensePubKeyEncoded) + return ncutils.ConvertBytesToKey(decodedPubKey) +} + +func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, error) { + + publicKeyBytes, err := ncutils.ConvertKeyToBytes(publicKey) + if err != nil { + return nil, err + } + + msg := ValidateLicenseRequest{ + NmServerPubKey: base64encode(publicKeyBytes), + EncryptedPart: base64encode(encryptedData), + } + + requestBody, err := json.Marshal(msg) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, api_endpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, err + } + reqParams := req.URL.Query() + reqParams.Add("licensevalue", servercfg.GetLicenseKey()) + req.URL.RawQuery = reqParams.Encode() + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + client := &http.Client{} + var body []byte + validateResponse, err := client.Do(req) + if err != nil { // check cache + body, err = getCachedResponse() + if err != nil { + return nil, err + } + logger.Log(3, "proceeding with cached response, Netmaker API may be down") + } else { + defer validateResponse.Body.Close() + if validateResponse.StatusCode != 200 { + return nil, fmt.Errorf("could not validate license") + } // if you received a 200 cache the response locally + + body, err = ioutil.ReadAll(validateResponse.Body) + if err != nil { + return nil, err + } + cacheResponse(body) + } + + return body, err +} + +func cacheResponse(response []byte) error { + var lrc = licenseResponseCache{ + Body: response, + } + + record, err := json.Marshal(&lrc) + if err != nil { + return err + } + + return database.Insert(license_cache_key, string(record), database.CACHE_TABLE_NAME) +} + +func getCachedResponse() ([]byte, error) { + var lrc licenseResponseCache + record, err := database.FetchRecord(database.CACHE_TABLE_NAME, license_cache_key) + if err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(record), &lrc); err != nil { + return nil, err + } + return lrc.Body, nil +} + +// ClearLicenseCache - clears the cached validate response +func ClearLicenseCache() error { + return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key) +} + +// AddServerIDIfNotPresent - add's current server ID to DB if not present +func AddServerIDIfNotPresent() error { + currentNodeID := servercfg.GetNodeID() + currentServerIDs := serverIDs{} + + record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key) + if err != nil && !database.IsEmptyRecord(err) { + return err + } else if err == nil { + if err = json.Unmarshal([]byte(record), ¤tServerIDs); err != nil { + return err + } + } + + if !logic.StringSliceContains(currentServerIDs.ServerIDs, currentNodeID) { + currentServerIDs.ServerIDs = append(currentServerIDs.ServerIDs, currentNodeID) + data, err := json.Marshal(¤tServerIDs) + if err != nil { + return err + } + return database.Insert(server_id_key, string(data), database.SERVERCONF_TABLE_NAME) + } + + return nil +} + +func getServerCount() int { + if record, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, server_id_key); err == nil { + currentServerIDs := serverIDs{} + if err = json.Unmarshal([]byte(record), ¤tServerIDs); err == nil { + return len(currentServerIDs.ServerIDs) + } + } + return 1 +} diff --git a/ee/types.go b/ee/types.go new file mode 100644 index 00000000..d802e1af --- /dev/null +++ b/ee/types.go @@ -0,0 +1,87 @@ +package ee + +import "fmt" + +const ( + api_endpoint = "https://api.controller.netmaker.io/api/v1/license/validate" + license_cache_key = "license_response_cache" + license_validation_err_msg = "invalid license" + server_id_key = "nm-server-id" +) + +var errValidation = fmt.Errorf(license_validation_err_msg) + +// Limits - limits to be referenced throughout server +var Limits = GlobalLimits{ + Servers: 0, + Users: 0, + Nodes: 0, + Clients: 0, + FreeTier: false, +} + +// GlobalLimits - struct for holding global limits on this netmaker server in memory +type GlobalLimits struct { + Servers int + Users int + Nodes int + Clients int + FreeTier bool + Networks int +} + +// LicenseKey - the license key struct representation with associated data +type LicenseKey struct { + LicenseValue string `json:"license_value"` // actual (public) key and the unique value for the key + Expiration int64 `json:"expiration"` + LimitServers int `json:"limit_servers"` + LimitUsers int `json:"limit_users"` + LimitNodes int `json:"limit_nodes"` + LimitClients int `json:"limit_clients"` + Metadata string `json:"metadata"` + SubscriptionID string `json:"subscription_id"` // for a paid subscription (non-free-tier license) + FreeTier string `json:"free_tier"` // yes if free tier + IsActive string `json:"is_active"` // yes if active +} + +// ValidatedLicense - the validated license struct +type ValidatedLicense struct { + LicenseValue string `json:"license_value" binding:"required"` // license that validation is being requested for + EncryptedLicense string `json:"encrypted_license" binding:"required"` // to be decrypted by Netmaker using Netmaker server's private key +} + +// LicenseSecret - the encrypted struct for sending user-id +type LicenseSecret struct { + UserID string `json:"user_id" binding:"required"` // UUID for user foreign key to User table + Limits LicenseLimits `json:"limits" binding:"required"` +} + +// LicenseLimits - struct license limits +type LicenseLimits struct { + Servers int `json:"servers" binding:"required"` + Users int `json:"users" binding:"required"` + Nodes int `json:"nodes" binding:"required"` + Clients int `json:"clients" binding:"required"` +} + +// LicenseLimits.SetDefaults - sets the default values for limits +func (l *LicenseLimits) SetDefaults() { + l.Clients = 0 + l.Servers = 1 + l.Nodes = 0 + l.Users = 1 +} + +// ValidateLicenseRequest - used for request to validate license endpoint +type ValidateLicenseRequest struct { + NmServerPubKey string `json:"nm_server_pub_key" binding:"required"` // Netmaker server public key used to send data back to Netmaker for the Netmaker server to decrypt (eg output from validating license) + EncryptedPart string `json:"secret" binding:"required"` +} + +type licenseResponseCache struct { + Body []byte `json:"body" binding:"required"` +} + +type serverIDs struct { + ServerIDs []string `json:"server_ids" binding:"required"` +} diff --git a/ee/util.go b/ee/util.go new file mode 100644 index 00000000..26e6262f --- /dev/null +++ b/ee/util.go @@ -0,0 +1,54 @@ +package ee + +import ( + "encoding/base64" + + "github.com/gravitl/netmaker/logic" +) + +var isEnterprise bool + +// SetIsEnterprise - sets server to use enterprise features +func SetIsEnterprise() { + isEnterprise = true +} + +// IsEnterprise - checks if enterprise binary or not +func IsEnterprise() bool { + return isEnterprise +} + +// base64encode - base64 encode helper function +func base64encode(input []byte) string { + return base64.StdEncoding.EncodeToString(input) +} + +// base64decode - base64 decode helper function +func base64decode(input string) []byte { + + bytes, err := base64.StdEncoding.DecodeString(input) + + if err != nil { + return nil + } + + return bytes +} + +func getCurrentServerLimit() (limits LicenseLimits) { + limits.SetDefaults() + nodes, err := logic.GetAllNodes() + if err == nil { + limits.Nodes = len(nodes) + } + clients, err := logic.GetAllExtClients() + if err == nil { + limits.Clients = len(clients) + } + users, err := logic.GetUsers() + if err == nil { + limits.Users = len(users) + } + limits.Servers = getServerCount() + return +} diff --git a/go.mod b/go.mod index 03f994b7..aecdff4e 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,10 @@ require ( github.com/txn2/txeh v1.3.0 github.com/urfave/cli/v2 v2.14.1 golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd + golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094 + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect + golang.org/x/text v0.3.7 // indirect golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220324164955-056925b7df31 google.golang.org/protobuf v1.28.1 // indirect @@ -30,6 +33,7 @@ require ( fyne.io/fyne/v2 v2.2.3 github.com/c-robinson/iplib v1.0.3 github.com/cloverstd/tcping v0.1.1 + github.com/go-ping/ping v1.1.0 github.com/guumaster/hostctl v1.1.3 github.com/kr/pretty v0.3.0 github.com/posthog/posthog-go v0.0.0-20211028072449-93c17c49e2b0 @@ -37,7 +41,9 @@ require ( require ( github.com/coreos/go-oidc/v3 v3.3.0 + github.com/gorilla/websocket v1.4.2 golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e + golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 ) require ( @@ -67,7 +73,6 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.8 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect - github.com/gorilla/websocket v1.4.2 // indirect github.com/josharian/native v1.0.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kr/text v0.2.0 // indirect @@ -91,10 +96,7 @@ require ( github.com/yuin/goldmark v1.4.13 // indirect golang.org/x/image v0.0.0-20220601225756-64ec528b34cd // indirect golang.org/x/mobile v0.0.0-20211207041440-4e6c2922fdee // indirect - golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect - golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 79bb120f..8f0568a0 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec h1:3FLiRYO6PlQFDpUU7OEFlWgjGD1jnBIVSJ5SYRWk+9c= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211213063430-748e38ca8aec/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ping/ping v1.1.0 h1:3MCGhVX4fyEUuhsfwPrsEdQw6xspHkv5zHsiSoDFZYw= +github.com/go-ping/ping v1.1.0/go.mod h1:xIFjORFzTxqIV/tDVGO4eDy/bLuSyawEeojSm3GfRGk= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= @@ -251,6 +253,7 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8= @@ -664,6 +667,7 @@ golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -700,6 +704,8 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdp golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/logic/accesskeys.go b/logic/accesskeys.go index 895dac18..e62599fd 100644 --- a/logic/accesskeys.go +++ b/logic/accesskeys.go @@ -161,11 +161,11 @@ func DecrimentKey(networkName string, keyvalue string) { } // IsKeyValid - check if key is valid -func IsKeyValid(networkname string, keyvalue string) bool { +func IsKeyValid(networkname string, keyvalue string) (string, bool) { network, err := GetParentNetwork(networkname) if err != nil { - return false + return "", false } accesskeys := network.AccessKeys @@ -185,7 +185,7 @@ func IsKeyValid(networkname string, keyvalue string) bool { isvalid = true } } - return isvalid + return key.Name, isvalid } // RemoveKeySensitiveInfo - remove sensitive key info diff --git a/logic/auth.go b/logic/auth.go index ba8205bc..4282f490 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -9,7 +9,9 @@ import ( "github.com/go-playground/validator/v10" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" "golang.org/x/crypto/bcrypt" ) @@ -95,8 +97,7 @@ func CreateUser(user models.User) (models.User, error) { // set password to encrypted password user.Password = string(hash) - tokenString, _ := CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) - + tokenString, _ := CreateProUserJWT(user.UserName, user.Networks, user.Groups, user.IsAdmin) if tokenString == "" { // returnErrorResponse(w, r, errorResponse) return user, err @@ -108,8 +109,45 @@ func CreateUser(user models.User) (models.User, error) { return user, err } err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) + if err != nil { + return user, err + } - return user, err + // == PRO == Add user to every network as network user == + currentNets, err := GetNetworks() + if err != nil { + currentNets = []models.Network{} + } + for i := range currentNets { + newUser := promodels.NetworkUser{ + ID: promodels.NetworkUserID(user.UserName), + Clients: []string{}, + Nodes: []string{}, + } + + pro.AddProNetDefaults(¤tNets[i]) + if pro.IsUserAllowed(¤tNets[i], user.UserName, user.Groups) { + newUser.AccessLevel = currentNets[i].ProSettings.DefaultAccessLevel + newUser.ClientLimit = currentNets[i].ProSettings.DefaultUserClientLimit + newUser.NodeLimit = currentNets[i].ProSettings.DefaultUserNodeLimit + } else { + newUser.AccessLevel = pro.NO_ACCESS + newUser.ClientLimit = 0 + newUser.NodeLimit = 0 + } + + // legacy + if StringSliceContains(user.Networks, currentNets[i].NetID) { + newUser.AccessLevel = pro.NET_ADMIN + } + userErr := pro.CreateNetworkUser(¤tNets[i], &newUser) + if userErr != nil { + logger.Log(0, "failed to add network user data on network", currentNets[i].NetID, "for user", user.UserName) + } + } + // == END PRO == + + return user, nil } // CreateAdmin - creates an admin user @@ -136,10 +174,10 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) { //Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved). record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName) if err != nil { - return "", errors.New("incorrect credentials") + return "", errors.New("error retrieving user from db: " + err.Error()) } if err = json.Unmarshal([]byte(record), &result); err != nil { - return "", errors.New("incorrect credentials") + return "", errors.New("error unmarshalling user json: " + err.Error()) } // compare password from request to stored password in database @@ -150,14 +188,15 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) { } //Create a new JWT for the node - tokenString, _ := CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin) + tokenString, _ := CreateProUserJWT(authRequest.UserName, result.Networks, result.Groups, result.IsAdmin) return tokenString, nil } // UpdateUserNetworks - updates the networks of a given user -func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.User) error { +func UpdateUserNetworks(newNetworks, newGroups []string, isadmin bool, currentUser *models.ReturnUser) error { // check if user exists - if returnedUser, err := GetUser(currentUser.UserName); err != nil { + returnedUser, err := GetUser(currentUser.UserName) + if err != nil { return err } else if returnedUser.IsAdmin { return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName) @@ -166,18 +205,46 @@ func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models. currentUser.IsAdmin = true currentUser.Networks = nil } else { + // == PRO == + currentUser.Groups = newGroups + for _, n := range newNetworks { + if !StringSliceContains(currentUser.Networks, n) { + // make net admin of any network not previously assigned + pro.MakeNetAdmin(n, currentUser.UserName) + } + } + // Compare networks, find networks not in previous + for _, n := range currentUser.Networks { + if !StringSliceContains(newNetworks, n) { + // if user was removed from a network, re-assign access to net default level + if network, err := GetNetwork(n); err == nil { + if network.ProSettings != nil { + ok := pro.AssignAccessLvl(n, currentUser.UserName, network.ProSettings.DefaultAccessLevel) + if ok { + logger.Log(0, "changed", currentUser.UserName, "access level on network", network.NetID, "to", fmt.Sprintf("%d", network.ProSettings.DefaultAccessLevel)) + } + } + } + } + } + + if err := AdjustGroupPermissions(currentUser); err != nil { + logger.Log(0, "failed to update user", currentUser.UserName, "after group update", err.Error()) + } + // == END PRO == + currentUser.Networks = newNetworks } - data, err := json.Marshal(currentUser) - if err != nil { - return err - } - if err = database.Insert(currentUser.UserName, string(data), database.USERS_TABLE_NAME); err != nil { - return err - } + _, err = UpdateUser(models.User{ + UserName: currentUser.UserName, + Networks: currentUser.Networks, + IsAdmin: currentUser.IsAdmin, + Password: "", + Groups: currentUser.Groups, + }, returnedUser) - return nil + return err } // UpdateUser - updates a given user @@ -187,11 +254,6 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) { return models.User{}, err } - err := ValidateUser(userchange) - if err != nil { - return models.User{}, err - } - queryUser := user.UserName if userchange.UserName != "" { @@ -200,6 +262,9 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) { if len(userchange.Networks) > 0 { user.Networks = userchange.Networks } + if len(userchange.Groups) > 0 { + user.Groups = userchange.Groups + } if userchange.Password != "" { // encrypt that password so we never see it again hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) @@ -212,6 +277,12 @@ func UpdateUser(userchange models.User, user models.User) (models.User, error) { user.Password = userchange.Password } + + err := ValidateUser(user) + if err != nil { + return models.User{}, err + } + if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { return models.User{}, err } @@ -256,6 +327,20 @@ func DeleteUser(user string) (bool, error) { if err != nil { return false, err } + + // == pro - remove user from all network user instances == + currentNets, err := GetNetworks() + if err != nil { + return true, err + } + + for i := range currentNets { + netID := currentNets[i].NetID + if err = pro.DeleteNetworkUser(netID, user); err != nil { + logger.Log(0, "failed to remove", user, "as network user from network", netID, err.Error()) + } + } + return true, nil } @@ -313,6 +398,9 @@ func IsStateValid(state string) (string, bool) { if s.Value != "" { delState(state) } + if err != nil { + logger.Log(2, "error retrieving oauth state:", err.Error()) + } return s.Value, err == nil } @@ -320,3 +408,51 @@ func IsStateValid(state string) (string, bool) { func delState(state string) error { return database.DeleteRecord(database.SSO_STATE_CACHE, state) } + +// PRO + +// AdjustGroupPermissions - adjusts a given user's network access based on group changes +func AdjustGroupPermissions(user *models.ReturnUser) error { + networks, err := GetNetworks() + if err != nil { + return err + } + // UPDATE + // go through all networks and see if new group is in + // if access level of current user is greater (value) than network's default + // assign network's default + // DELETE + // if user not allowed on network a + for i := range networks { + AdjustNetworkUserPermissions(user, &networks[i]) + } + + return nil +} + +// AdjustGroupPermissions - adjusts a given user's network access based on group changes +func AdjustNetworkUserPermissions(user *models.ReturnUser, network *models.Network) error { + networkUser, err := pro.GetNetworkUser( + network.NetID, + promodels.NetworkUserID(user.UserName), + ) + if err == nil && network.ProSettings != nil { + if pro.IsUserAllowed(network, user.UserName, user.Groups) { + if networkUser.AccessLevel > network.ProSettings.DefaultAccessLevel { + networkUser.AccessLevel = network.ProSettings.DefaultAccessLevel + } + if networkUser.NodeLimit < network.ProSettings.DefaultUserNodeLimit { + networkUser.NodeLimit = network.ProSettings.DefaultUserNodeLimit + } + if networkUser.ClientLimit < network.ProSettings.DefaultUserClientLimit { + networkUser.ClientLimit = network.ProSettings.DefaultUserClientLimit + } + } else { + networkUser.AccessLevel = pro.NO_ACCESS + networkUser.NodeLimit = 0 + networkUser.ClientLimit = 0 + } + pro.UpdateNetworkUser(network.NetID, networkUser) + } + return err +} diff --git a/logic/extpeers.go b/logic/extpeers.go index b7fbb299..98a56a6e 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -183,3 +183,40 @@ func UpdateExtClient(newclientid string, network string, enabled bool, client *m CreateExtClient(client) return client, err } + +// GetExtClientsByID - gets the clients of attached gateway +func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) { + var result []models.ExtClient + currentClients, err := GetNetworkExtClients(network) + if err != nil { + return result, err + } + for i := range currentClients { + if currentClients[i].IngressGatewayID == nodeid { + result = append(result, currentClients[i]) + } + } + return result, nil +} + +// GetAllExtClients - gets all ext clients from DB +func GetAllExtClients() ([]models.ExtClient, error) { + var clients = []models.ExtClient{} + currentNetworks, err := GetNetworks() + if err != nil && database.IsEmptyRecord(err) { + return clients, nil + } else if err != nil { + return clients, err + } + + for i := range currentNetworks { + netName := currentNetworks[i].NetID + netClients, err := GetNetworkExtClients(netName) + if err != nil { + continue + } + clients = append(clients, netClients...) + } + + return clients, nil +} diff --git a/logic/jwts.go b/logic/jwts.go index f91a92be..4ddad9ce 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -53,6 +53,30 @@ func CreateJWT(uuid string, macAddress string, network string) (response string, return "", err } +// CreateProUserJWT - creates a user jwt token +func CreateProUserJWT(username string, networks, groups []string, isadmin bool) (response string, err error) { + expirationTime := time.Now().Add(60 * 12 * time.Minute) + claims := &models.UserClaims{ + UserName: username, + Networks: networks, + IsAdmin: isadmin, + Groups: groups, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "Netmaker", + Subject: fmt.Sprintf("user|%s", username), + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtSecretKey) + if err == nil { + return tokenString, nil + } + return "", err +} + // CreateUserJWT - creates a user jwt token func CreateUserJWT(username string, networks []string, isadmin bool) (response string, err error) { expirationTime := time.Now().Add(60 * 12 * time.Minute) diff --git a/logic/metrics.go b/logic/metrics.go new file mode 100644 index 00000000..f60bcd76 --- /dev/null +++ b/logic/metrics.go @@ -0,0 +1,65 @@ +package logic + +import ( + "encoding/json" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" +) + +// GetMetrics - gets the metrics +func GetMetrics(nodeid string) (*models.Metrics, error) { + var metrics models.Metrics + record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid) + if err != nil { + if database.IsEmptyRecord(err) { + return &metrics, nil + } + return &metrics, err + } + err = json.Unmarshal([]byte(record), &metrics) + if err != nil { + return &metrics, err + } + return &metrics, nil +} + +// UpdateMetrics - updates the metrics of a given client +func UpdateMetrics(nodeid string, metrics *models.Metrics) error { + data, err := json.Marshal(metrics) + if err != nil { + return err + } + return database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME) +} + +// DeleteMetrics - deletes metrics of a given node +func DeleteMetrics(nodeid string) error { + return database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid) +} + +// CollectServerMetrics - collects metrics for given server node +func CollectServerMetrics(serverID string, networkNodes []models.Node) *models.Metrics { + newServerMetrics := models.Metrics{} + newServerMetrics.Connectivity = make(map[string]models.Metric) + for i := range networkNodes { + currNodeID := networkNodes[i].ID + if currNodeID == serverID { + continue + } + if currMetrics, err := GetMetrics(currNodeID); err == nil { + if currMetrics.Connectivity != nil && currMetrics.Connectivity[serverID].Connected { + metrics := currMetrics.Connectivity[serverID] + metrics.NodeName = networkNodes[i].Name + metrics.IsServer = "no" + newServerMetrics.Connectivity[currNodeID] = metrics + } + } else { + newServerMetrics.Connectivity[currNodeID] = models.Metric{ + Connected: false, + Latency: 999, + } + } + } + return &newServerMetrics +} diff --git a/logic/networks.go b/logic/networks.go index 0a3ad2dc..f82f22a0 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -13,6 +13,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic/acls/nodeacls" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/validation" @@ -62,6 +63,9 @@ func DeleteNetwork(network string) error { } else { logger.Log(1, "could not remove servers before deleting network", network) } + if err = pro.RemoveAllNetworkUsers(network); err != nil { + logger.Log(0, "failed to remove network users on network delete for network", network, err.Error()) + } return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network) } return errors.New("node check failed. All nodes must be deleted before deleting network") @@ -88,20 +92,32 @@ func CreateNetwork(network models.Network) (models.Network, error) { network.SetNodesLastModified() network.SetNetworkLastModified() + pro.AddProNetDefaults(&network) + err := ValidateNetwork(&network, false) if err != nil { //returnErrorResponse(w, r, formatError(err, "badrequest")) return models.Network{}, err } + if err = pro.InitializeNetworkUsers(network.NetID); err != nil { + return models.Network{}, err + } + data, err := json.Marshal(&network) if err != nil { return models.Network{}, err } + if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil { return models.Network{}, err } + // == add all current users to network as network users == + if err = InitializeNetUsers(&network); err != nil { + return network, err + } + return network, nil } @@ -526,25 +542,29 @@ func IsNetworkNameUnique(network *models.Network) (bool, error) { } // UpdateNetwork - updates a network with another network's fields -func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, bool, error) { +func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, bool, []string, []string, error) { if err := ValidateNetwork(newNetwork, true); err != nil { - return false, false, false, false, err + return false, false, false, false, nil, nil, err } if newNetwork.NetID == currentNetwork.NetID { hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6 localrangeupdate := newNetwork.LocalRange != currentNetwork.LocalRange hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch + groupDelta := append(StringDifference(newNetwork.ProSettings.AllowedGroups, currentNetwork.ProSettings.AllowedGroups), + StringDifference(currentNetwork.ProSettings.AllowedGroups, newNetwork.ProSettings.AllowedGroups)...) + userDelta := append(StringDifference(newNetwork.ProSettings.AllowedUsers, currentNetwork.ProSettings.AllowedUsers), + StringDifference(currentNetwork.ProSettings.AllowedUsers, newNetwork.ProSettings.AllowedUsers)...) data, err := json.Marshal(newNetwork) if err != nil { - return false, false, false, false, err + return false, false, false, false, nil, nil, err } newNetwork.SetNetworkLastModified() err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME) - return hasrangeupdate4, hasrangeupdate6, localrangeupdate, hasholepunchupdate, err + return hasrangeupdate4, hasrangeupdate6, localrangeupdate, hasholepunchupdate, groupDelta, userDelta, err } // copy values - return false, false, false, false, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.") + return false, false, false, false, nil, nil, errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.") } // GetNetwork - gets a network from database @@ -596,6 +616,15 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error { } } + if network.ProSettings != nil { + if network.ProSettings.DefaultAccessLevel < pro.NET_ADMIN || network.ProSettings.DefaultAccessLevel > pro.NO_ACCESS { + return fmt.Errorf("invalid access level") + } + if network.ProSettings.DefaultUserClientLimit < 0 || network.ProSettings.DefaultUserNodeLimit < 0 { + return fmt.Errorf("invalid node/client limit provided") + } + } + return err } diff --git a/logic/nodes.go b/logic/nodes.go index b1b6e916..84227f28 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -13,6 +13,8 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic/acls" "github.com/gravitl/netmaker/logic/acls/nodeacls" + "github.com/gravitl/netmaker/logic/pro" + "github.com/gravitl/netmaker/logic/pro/proacls" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" @@ -128,6 +130,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { } } } + nodeACLDelta := currentNode.DefaultACL != newNode.DefaultACL newNode.Fill(currentNode) if currentNode.IsServer == "yes" && !validateServer(currentNode, newNode) { @@ -137,7 +140,15 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { if err := ValidateNode(newNode, true); err != nil { return err } + if newNode.ID == currentNode.ID { + if nodeACLDelta { + if err := updateProNodeACLS(newNode); err != nil { + logger.Log(1, "failed to apply node level ACLs during creation of node", newNode.ID, "-", err.Error()) + return err + } + } + newNode.SetLastModified() if data, err := json.Marshal(newNode); err != nil { return err @@ -145,6 +156,7 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { return database.Insert(newNode.ID, string(data), database.NODES_TABLE_NAME) } } + return fmt.Errorf("failed to update node " + currentNode.ID + ", cannot change ID.") } @@ -176,9 +188,16 @@ func DeleteNodeByID(node *models.Node, exterminate bool) error { if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil { return err } + if servercfg.IsDNSMode() { SetDNS() } + if node.OwnerID != "" { + err = pro.DissociateNetworkUserNode(node.OwnerID, node.Network, node.ID) + if err != nil { + logger.Log(0, "failed to dissasociate", node.OwnerID, "from node", node.ID, ":", err.Error()) + } + } _, err = nodeacls.RemoveNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID)) if err != nil { @@ -186,6 +205,10 @@ func DeleteNodeByID(node *models.Node, exterminate bool) error { logger.Log(2, "attempted to remove node ACL for node", node.Name, node.ID) } // removeZombie <- node.ID + if err = DeleteMetrics(node.ID); err != nil { + logger.Log(1, "unable to remove metrics from DB for node", node.ID, err.Error()) + } + if node.IsServer == "yes" { return removeLocalServer(node) } @@ -219,6 +242,9 @@ func ValidateNode(node *models.Node, isUpdate bool) error { _ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool { return validation.CheckYesOrNo(fl) }) + _ = v.RegisterValidation("checkyesornoorunset", func(fl validator.FieldLevel) bool { + return validation.CheckYesOrNoOrUnset(fl) + }) err := v.Struct(node) return err @@ -255,6 +281,10 @@ func CreateNode(node *models.Node) error { } } + if node.DefaultACL == "" { + node.DefaultACL = "unset" + } + reverse := node.IsServer == "yes" if node.Address == "" { if parentNetwork.IsIPv4 == "yes" { @@ -305,9 +335,19 @@ func CreateNode(node *models.Node) error { return err } + if err = updateProNodeACLS(node); err != nil { + logger.Log(1, "failed to apply node level ACLs during creation of node", node.ID, "-", err.Error()) + return err + } + if node.IsPending != "yes" { DecrimentKey(node.Network, node.AccessKey) } + + if err = UpdateMetrics(node.ID, &models.Metrics{Connectivity: make(map[string]models.Metric)}); err != nil { + logger.Log(1, "failed to initialize metrics for node", node.Name, node.ID, err.Error()) + } + SetNetworkNodesLastModified(node.Network) if servercfg.IsDNSMode() { err = SetDNS() @@ -677,3 +717,19 @@ func findNode(ip string) (*models.Node, error) { } return nil, errors.New("node not found") } + +// == PRO == + +func updateProNodeACLS(node *models.Node) error { + // == PRO node ACLs == + networkNodes, err := GetNetworkNodes(node.Network) + if err != nil { + return err + } + if err = proacls.AdjustNodeAcls(node, networkNodes[:]); err != nil { + return err + } + return nil +} + +// == END PRO == diff --git a/logic/peers.go b/logic/peers.go index c197f3f1..16a86fe1 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -32,6 +32,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { } else if network.IsPointToSite == "yes" && node.IsHub != "yes" { isP2S = true } + var peerMap = make(models.PeerMap) // udppeers = the peers parsed from the local interface // gives us correct port to reach @@ -150,14 +151,24 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { } peers = append(peers, peerData) + peerMap[peer.PublicKey] = models.IDandAddr{ + Name: peer.Name, + ID: peer.ID, + Address: peer.PrimaryAddress(), + IsServer: peer.IsServer, + } + if peer.IsServer == "yes" { serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address}) } } if node.IsIngressGateway == "yes" { - extPeers, err := getExtPeers(node) + extPeers, idsAndAddr, err := getExtPeers(node) if err == nil { peers = append(peers, extPeers...) + for i := range idsAndAddr { + peerMap[idsAndAddr[i].ID] = idsAndAddr[i] + } } else { log.Println("ERROR RETRIEVING EXTERNAL PEERS", err) } @@ -168,14 +179,16 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) { peerUpdate.Peers = peers peerUpdate.ServerAddrs = serverNodeAddresses peerUpdate.DNS = getPeerDNS(node.Network) + peerUpdate.PeerIDs = peerMap return peerUpdate, nil } -func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, error) { +func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, []models.IDandAddr, error) { var peers []wgtypes.PeerConfig + var idsAndAddr []models.IDandAddr extPeers, err := GetExtPeersList(node) if err != nil { - return peers, err + return peers, idsAndAddr, err } for _, extPeer := range extPeers { pubkey, err := wgtypes.ParseKey(extPeer.PublicKey) @@ -209,14 +222,24 @@ func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, error) { allowedips = append(allowedips, addr6) } } + + primaryAddr := extPeer.Address + if primaryAddr == "" { + primaryAddr = extPeer.Address6 + } + peer = wgtypes.PeerConfig{ PublicKey: pubkey, ReplaceAllowedIPs: true, AllowedIPs: allowedips, } peers = append(peers, peer) + idsAndAddr = append(idsAndAddr, models.IDandAddr{ + ID: peer.PublicKey.String(), + Address: primaryAddr, + }) } - return peers, nil + return peers, idsAndAddr, nil } @@ -282,7 +305,7 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet { // handle ingress gateway peers if peer.IsIngressGateway == "yes" { - extPeers, err := getExtPeers(peer) + extPeers, _, err := getExtPeers(peer) if err != nil { logger.Log(2, "could not retrieve ext peers for ", peer.Name, err.Error()) } @@ -334,7 +357,7 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet { allowedips = append(allowedips, extAllowedIPs...) } if relayedNode.IsIngressGateway == "yes" { - extPeers, err := getExtPeers(relayedNode) + extPeers, _, err := getExtPeers(relayedNode) if err == nil { for _, extPeer := range extPeers { allowedips = append(allowedips, extPeer.AllowedIPs...) @@ -487,7 +510,7 @@ func GetPeerUpdateForRelayedNode(node *models.Node, udppeers map[string]string) } //if ingress add extclients if node.IsIngressGateway == "yes" { - extPeers, err := getExtPeers(node) + extPeers, _, err := getExtPeers(node) if err == nil { peers = append(peers, extPeers...) } else { diff --git a/logic/pro/license.go b/logic/pro/license.go new file mode 100644 index 00000000..2ca96d50 --- /dev/null +++ b/logic/pro/license.go @@ -0,0 +1,66 @@ +package pro + +import ( + "crypto/rand" + "encoding/json" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/netclient/ncutils" + "golang.org/x/crypto/nacl/box" +) + +const ( + db_license_key = "netmaker-id-key-pair" +) + +type apiServerConf struct { + PrivateKey []byte `json:"private_key" binding:"required"` + PublicKey []byte `json:"public_key" binding:"required"` +} + +// FetchApiServerKeys - fetches netmaker license keys for identification +// as well as secure communication with API +// if none present, it generates a new pair +func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { + var returnData = apiServerConf{} + currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key) + if err != nil && !database.IsEmptyRecord(err) { + return nil, nil, err + } else if database.IsEmptyRecord(err) { // need to generate a new identifier pair + pub, priv, err = box.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + pubBytes, err := ncutils.ConvertKeyToBytes(pub) + if err != nil { + return nil, nil, err + } + privBytes, err := ncutils.ConvertKeyToBytes(priv) + if err != nil { + return nil, nil, err + } + returnData.PrivateKey = privBytes + returnData.PublicKey = pubBytes + record, err := json.Marshal(&returnData) + if err != nil { + return nil, nil, err + } + if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil { + return nil, nil, err + } + } else { + if err = json.Unmarshal([]byte(currentData), &returnData); err != nil { + return nil, nil, err + } + priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey) + if err != nil { + return nil, nil, err + } + pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey) + if err != nil { + return nil, nil, err + } + } + + return pub, priv, nil +} diff --git a/logic/pro/metrics/metrics.go b/logic/pro/metrics/metrics.go new file mode 100644 index 00000000..c660b325 --- /dev/null +++ b/logic/pro/metrics/metrics.go @@ -0,0 +1,121 @@ +package metrics + +import ( + "github.com/go-ping/ping" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "golang.zx2c4.com/wireguard/wgctrl" +) + +// Collect - collects metrics +func Collect(iface string, peerMap models.PeerMap) (*models.Metrics, error) { + var metrics models.Metrics + metrics.Connectivity = make(map[string]models.Metric) + var wgclient, err = wgctrl.New() + if err != nil { + fillUnconnectedData(&metrics, peerMap) + return &metrics, err + } + defer wgclient.Close() + device, err := wgclient.Device(iface) + if err != nil { + fillUnconnectedData(&metrics, peerMap) + return &metrics, err + } + // TODO handle freebsd?? + for i := range device.Peers { + currPeer := device.Peers[i] + id := peerMap[currPeer.PublicKey.String()].ID + address := peerMap[currPeer.PublicKey.String()].Address + if id == "" || address == "" { + logger.Log(0, "attempted to parse metrics for invalid peer from server", id, address) + continue + } + var newMetric = models.Metric{ + NodeName: peerMap[currPeer.PublicKey.String()].Name, + IsServer: peerMap[currPeer.PublicKey.String()].IsServer, + } + logger.Log(2, "collecting metrics for peer", address) + newMetric.TotalReceived = currPeer.ReceiveBytes + newMetric.TotalSent = currPeer.TransmitBytes + + // get latency + pinger, err := ping.NewPinger(address) + if err != nil { + logger.Log(0, "could not initiliaze ping for metrics on peer address", address, err.Error()) + newMetric.Connected = false + newMetric.Latency = 999 + } else { + pinger.Count = 1 + err = pinger.Run() + if err != nil { + logger.Log(0, "failed ping for metrics on peer address", address, err.Error()) + newMetric.Connected = false + newMetric.Latency = 999 + } else { + pingStats := pinger.Statistics() + newMetric.Uptime = 1 + newMetric.Connected = true + newMetric.Latency = pingStats.AvgRtt.Milliseconds() + } + } + newMetric.TotalTime = 1 + metrics.Connectivity[id] = newMetric + } + + fillUnconnectedData(&metrics, peerMap) + return &metrics, nil +} + +// GetExchangedBytesForNode - get exchanged bytes for current node peers +func GetExchangedBytesForNode(node *models.Node, metrics *models.Metrics) error { + + peers, err := logic.GetPeerUpdate(node) + if err != nil { + logger.Log(0, "Failed to get peers: ", err.Error()) + return err + } + wgclient, err := wgctrl.New() + if err != nil { + return err + } + defer wgclient.Close() + device, err := wgclient.Device(node.Interface) + if err != nil { + return err + } + for _, currPeer := range device.Peers { + id := peers.PeerIDs[currPeer.PublicKey.String()].ID + address := peers.PeerIDs[currPeer.PublicKey.String()].Address + if id == "" || address == "" { + logger.Log(0, "attempted to parse metrics for invalid peer from server", id, address) + continue + } + logger.Log(2, "collecting exchanged bytes info for peer: ", address) + peerMetric := metrics.Connectivity[id] + peerMetric.TotalReceived = currPeer.ReceiveBytes + peerMetric.TotalSent = currPeer.TransmitBytes + metrics.Connectivity[id] = peerMetric + } + return nil +} + +// == used to fill zero value data for non connected peers == +func fillUnconnectedData(metrics *models.Metrics, peerMap models.PeerMap) { + for r := range peerMap { + id := peerMap[r].ID + if !metrics.Connectivity[id].Connected { + newMetric := models.Metric{ + NodeName: peerMap[r].Name, + IsServer: peerMap[r].IsServer, + Uptime: 0, + TotalTime: 1, + Connected: false, + Latency: 999, + PercentUp: 0, + } + metrics.Connectivity[id] = newMetric + } + } +} diff --git a/logic/pro/netcache/netcache.go b/logic/pro/netcache/netcache.go new file mode 100644 index 00000000..901f610c --- /dev/null +++ b/logic/pro/netcache/netcache.go @@ -0,0 +1,57 @@ +package netcache + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/gravitl/netmaker/database" +) + +const ( + expirationTime = time.Minute * 5 +) + +// CValue - the cache object for a network +type CValue struct { + Network string `json:"network"` + Value string `json:"value"` + Pass string `json:"pass"` + User string `json:"user"` + Expiration time.Time `json:"expiration"` +} + +var errExpired = fmt.Errorf("expired") + +// Set - sets a value to a key in db +func Set(k string, newValue *CValue) error { + newValue.Expiration = time.Now().Add(expirationTime) + newData, err := json.Marshal(newValue) + if err != nil { + return err + } + + return database.Insert(k, string(newData), database.CACHE_TABLE_NAME) +} + +// Get - gets a value from db, if expired, return err +func Get(k string) (*CValue, error) { + record, err := database.FetchRecord(database.CACHE_TABLE_NAME, k) + if err != nil { + return nil, err + } + var entry CValue + if err := json.Unmarshal([]byte(record), &entry); err != nil { + return nil, err + } + if time.Now().After(entry.Expiration) { + return nil, errExpired + } + + return &entry, nil +} + +// Del - deletes a value from db +func Del(k string) error { + return database.DeleteRecord(database.CACHE_TABLE_NAME, k) +} diff --git a/logic/pro/networks.go b/logic/pro/networks.go new file mode 100644 index 00000000..0b07311c --- /dev/null +++ b/logic/pro/networks.go @@ -0,0 +1,62 @@ +package pro + +import ( + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" +) + +// AddProNetDefaults - adds default values to a network model +func AddProNetDefaults(network *models.Network) { + if network.ProSettings == nil { + newProSettings := promodels.ProNetwork{ + DefaultAccessLevel: NO_ACCESS, + DefaultUserNodeLimit: 0, + DefaultUserClientLimit: 0, + AllowedUsers: []string{}, + AllowedGroups: []string{}, + } + network.ProSettings = &newProSettings + } +} + +// isUserGroupAllowed - checks if a user group is allowed on a network +func isUserGroupAllowed(network *models.Network, groupName string) bool { + if network.ProSettings != nil { + if len(network.ProSettings.AllowedGroups) > 0 { + for i := range network.ProSettings.AllowedGroups { + currentGroup := network.ProSettings.AllowedGroups[i] + if currentGroup == DEFAULT_ALLOWED_GROUPS || currentGroup == groupName { + return true + } + } + } + } + return false +} + +func isUserInAllowedUsers(network *models.Network, userName string) bool { + if network.ProSettings != nil { + if len(network.ProSettings.AllowedUsers) > 0 { + for i := range network.ProSettings.AllowedUsers { + currentUser := network.ProSettings.AllowedUsers[i] + if currentUser == DEFAULT_ALLOWED_USERS || currentUser == userName { + return true + } + } + } + } + return false +} + +// IsUserAllowed - checks if given username + groups if a user is allowed on network +func IsUserAllowed(network *models.Network, userName string, groups []string) bool { + isGroupAllowed := false + for _, g := range groups { + if isUserGroupAllowed(network, g) { + isGroupAllowed = true + break + } + } + + return isUserInAllowedUsers(network, userName) || isGroupAllowed +} diff --git a/logic/pro/networks_test.go b/logic/pro/networks_test.go new file mode 100644 index 00000000..68915a3b --- /dev/null +++ b/logic/pro/networks_test.go @@ -0,0 +1,64 @@ +package pro + +import ( + "testing" + + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" + "github.com/stretchr/testify/assert" +) + +func TestNetworkProSettings(t *testing.T) { + t.Run("Uninitialized with pro", func(t *testing.T) { + network := models.Network{ + NetID: "helloworld", + } + assert.Nil(t, network.ProSettings) + }) + t.Run("Initialized with pro", func(t *testing.T) { + network := models.Network{ + NetID: "helloworld", + } + AddProNetDefaults(&network) + assert.NotNil(t, network.ProSettings) + }) + t.Run("Net Zero Defaults set correctly with Pro", func(t *testing.T) { + network := models.Network{ + NetID: "helloworld", + } + AddProNetDefaults(&network) + assert.NotNil(t, network.ProSettings) + assert.Equal(t, NO_ACCESS, network.ProSettings.DefaultAccessLevel) + assert.Equal(t, 0, network.ProSettings.DefaultUserClientLimit) + assert.Equal(t, 0, network.ProSettings.DefaultUserNodeLimit) + }) + t.Run("Net Defaults set correctly with Pro", func(t *testing.T) { + network := models.Network{ + NetID: "helloworld", + ProSettings: &promodels.ProNetwork{ + DefaultAccessLevel: NET_ADMIN, + DefaultUserNodeLimit: 10, + DefaultUserClientLimit: 25, + }, + } + AddProNetDefaults(&network) + assert.NotNil(t, network.ProSettings) + assert.Equal(t, NET_ADMIN, network.ProSettings.DefaultAccessLevel) + assert.Equal(t, 25, network.ProSettings.DefaultUserClientLimit) + assert.Equal(t, 10, network.ProSettings.DefaultUserNodeLimit) + }) + t.Run("Net Defaults set to allow all groups/users", func(t *testing.T) { + network := models.Network{ + NetID: "helloworld", + ProSettings: &promodels.ProNetwork{ + DefaultAccessLevel: NET_ADMIN, + DefaultUserNodeLimit: 10, + DefaultUserClientLimit: 25, + }, + } + AddProNetDefaults(&network) + assert.NotNil(t, network.ProSettings) + assert.Nil(t, network.ProSettings.AllowedGroups) + assert.Nil(t, network.ProSettings.AllowedUsers) + }) +} diff --git a/logic/pro/networkuser.go b/logic/pro/networkuser.go new file mode 100644 index 00000000..49c33640 --- /dev/null +++ b/logic/pro/networkuser.go @@ -0,0 +1,247 @@ +package pro + +import ( + "encoding/json" + "fmt" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" +) + +// InitializeNetworkUsers - intializes network users for a given network +func InitializeNetworkUsers(network string) error { + newNetUserMap := make(promodels.NetworkUserMap) + netUserData, err := json.Marshal(newNetUserMap) + if err != nil { + return err + } + + return database.Insert(network, string(netUserData), database.NETWORK_USER_TABLE_NAME) +} + +// GetNetworkUsers - gets the network users table +func GetNetworkUsers(network string) (promodels.NetworkUserMap, error) { + currentUsers, err := database.FetchRecord(database.NETWORK_USER_TABLE_NAME, network) + if err != nil { + return nil, err + } + var userMap promodels.NetworkUserMap + if err = json.Unmarshal([]byte(currentUsers), &userMap); err != nil { + return nil, err + } + return userMap, nil +} + +// CreateNetworkUser - adds a network user to db +func CreateNetworkUser(network *models.Network, user *promodels.NetworkUser) error { + + if DoesNetworkUserExist(network.NetID, user.ID) { + return nil + } + + currentUsers, err := GetNetworkUsers(network.NetID) + if err != nil { + return err + } + + currentUsers.Add(user) + data, err := json.Marshal(currentUsers) + if err != nil { + return err + } + + return database.Insert(network.NetID, string(data), database.NETWORK_USER_TABLE_NAME) +} + +// DeleteNetworkUser - deletes a network user and removes from all networks +func DeleteNetworkUser(network, userid string) error { + currentUsers, err := GetNetworkUsers(network) + if err != nil { + return err + } + + currentUsers.Delete(promodels.NetworkUserID(userid)) + data, err := json.Marshal(currentUsers) + if err != nil { + return err + } + + return database.Insert(network, string(data), database.NETWORK_USER_TABLE_NAME) +} + +// DissociateNetworkUserNode - removes a node from a given user's node list +func DissociateNetworkUserNode(userid, networkid, nodeid string) error { + nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid)) + if err != nil { + return err + } + for i, n := range nuser.Nodes { + if n == nodeid { + nuser.Nodes = removeStringIndex(nuser.Nodes, i) + break + } + } + return UpdateNetworkUser(networkid, nuser) +} + +// DissociateNetworkUserClient - removes a client from a given user's client list +func DissociateNetworkUserClient(userid, networkid, clientid string) error { + nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid)) + if err != nil { + return err + } + for i, n := range nuser.Clients { + if n == clientid { + nuser.Clients = removeStringIndex(nuser.Clients, i) + break + } + } + return UpdateNetworkUser(networkid, nuser) +} + +// AssociateNetworkUserClient - removes a client from a given user's client list +func AssociateNetworkUserClient(userid, networkid, clientid string) error { + nuser, err := GetNetworkUser(networkid, promodels.NetworkUserID(userid)) + if err != nil { + return err + } + var found bool + for _, n := range nuser.Clients { + if n == clientid { + found = true + break + } + } + if found { + return nil + } else { + nuser.Clients = append(nuser.Clients, clientid) + } + + return UpdateNetworkUser(networkid, nuser) +} + +func removeStringIndex(s []string, index int) []string { + ret := make([]string, 0) + ret = append(ret, s[:index]...) + return append(ret, s[index+1:]...) +} + +// GetNetworkUser - fetches a network user from a given network +func GetNetworkUser(network string, userID promodels.NetworkUserID) (*promodels.NetworkUser, error) { + currentUsers, err := GetNetworkUsers(network) + if err != nil { + return nil, err + } + if currentUsers[userID].ID == "" { + return nil, fmt.Errorf("user %s does not exist", userID) + } + currentNetUser := currentUsers[userID] + return ¤tNetUser, nil +} + +// DoesNetworkUserExist - check if networkuser exists +func DoesNetworkUserExist(network string, userID promodels.NetworkUserID) bool { + _, err := GetNetworkUser(network, userID) + return err == nil +} + +// UpdateNetworkUser - gets a network user from given network +func UpdateNetworkUser(network string, newUser *promodels.NetworkUser) error { + currentUsers, err := GetNetworkUsers(network) + if err != nil { + return err + } + + currentUsers[newUser.ID] = *newUser + newUsersData, err := json.Marshal(¤tUsers) + if err != nil { + return err + } + + return database.Insert(network, string(newUsersData), database.NETWORK_USER_TABLE_NAME) +} + +// RemoveAllNetworkUsers - removes all network users from given network +func RemoveAllNetworkUsers(network string) error { + return database.DeleteRecord(database.NETWORK_USER_TABLE_NAME, network) +} + +// IsUserNodeAllowed - given a list of nodes, determine if the user's node is allowed based on ID +// Checks if node is in given nodes list as well as being in user's list +func IsUserNodeAllowed(nodes []models.Node, network, userID, nodeID string) bool { + + netUser, err := GetNetworkUser(network, promodels.NetworkUserID(userID)) + if err != nil { + return false + } + + for i := range nodes { + if nodes[i].ID == nodeID { + for j := range netUser.Nodes { + if netUser.Nodes[j] == nodeID { + return true + } + } + } + } + return false +} + +// IsUserClientAllowed - given a list of clients, determine if the user's client is allowed based on ID +// Checks if client is in given ext client list as well as being in user's list +func IsUserClientAllowed(clients []models.ExtClient, network, userID, clientID string) bool { + + netUser, err := GetNetworkUser(network, promodels.NetworkUserID(userID)) + if err != nil { + return false + } + + for i := range clients { + if clients[i].ClientID == clientID { + for j := range netUser.Clients { + if netUser.Clients[j] == clientID { + return true + } + } + } + } + return false +} + +// IsUserNetAdmin - checks if a user is a net admin or not +func IsUserNetAdmin(network, userID string) bool { + var isAdmin bool + user, err := GetNetworkUser(network, promodels.NetworkUserID(userID)) + if err != nil { + return isAdmin + } + return user.AccessLevel == NET_ADMIN +} + +// MakeNetAdmin - makes a given user a network admin on given network +func MakeNetAdmin(network, userID string) (ok bool) { + user, err := GetNetworkUser(network, promodels.NetworkUserID(userID)) + if err != nil { + return ok + } + user.AccessLevel = NET_ADMIN + if err = UpdateNetworkUser(network, user); err != nil { + return ok + } + return true +} + +// AssignAccessLvl - gives a user a specified access level +func AssignAccessLvl(network, userID string, accesslvl int) (ok bool) { + user, err := GetNetworkUser(network, promodels.NetworkUserID(userID)) + if err != nil { + return ok + } + user.AccessLevel = accesslvl + if err = UpdateNetworkUser(network, user); err != nil { + return ok + } + return true +} diff --git a/logic/pro/networkuser_test.go b/logic/pro/networkuser_test.go new file mode 100644 index 00000000..d3955aa1 --- /dev/null +++ b/logic/pro/networkuser_test.go @@ -0,0 +1,98 @@ +package pro + +import ( + "testing" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" + "github.com/stretchr/testify/assert" +) + +func TestNetworkUserLogic(t *testing.T) { + database.InitializeDatabase() + networkUser := promodels.NetworkUser{ + ID: "helloworld", + } + network := models.Network{ + NetID: "skynet", + AddressRange: "192.168.0.0/24", + } + nodes := []models.Node{ + models.Node{ID: "coolnode"}, + } + + clients := []models.ExtClient{ + models.ExtClient{ + ClientID: "coolclient", + }, + } + AddProNetDefaults(&network) + t.Run("Net Users initialized successfully", func(t *testing.T) { + err := InitializeNetworkUsers(network.NetID) + assert.Nil(t, err) + }) + + t.Run("Error when no network users", func(t *testing.T) { + user, err := GetNetworkUser(network.NetID, networkUser.ID) + assert.Nil(t, user) + assert.NotNil(t, err) + }) + + t.Run("Successful net user create", func(t *testing.T) { + DeleteNetworkUser(network.NetID, string(networkUser.ID)) + err := CreateNetworkUser(&network, &networkUser) + assert.Nil(t, err) + user, err := GetNetworkUser(network.NetID, networkUser.ID) + assert.NotNil(t, user) + assert.Nil(t, err) + assert.Equal(t, 0, user.AccessLevel) + assert.Equal(t, 0, user.ClientLimit) + }) + + t.Run("Successful net user update", func(t *testing.T) { + networkUser.AccessLevel = 0 + networkUser.ClientLimit = 1 + err := UpdateNetworkUser(network.NetID, &networkUser) + assert.Nil(t, err) + user, err := GetNetworkUser(network.NetID, networkUser.ID) + assert.NotNil(t, user) + assert.Nil(t, err) + assert.Equal(t, 0, user.AccessLevel) + assert.Equal(t, 1, user.ClientLimit) + }) + + t.Run("Successful net user node isallowed", func(t *testing.T) { + networkUser.Nodes = append(networkUser.Nodes, "coolnode") + err := UpdateNetworkUser(network.NetID, &networkUser) + assert.Nil(t, err) + isUserNodeAllowed := IsUserNodeAllowed(nodes[:], network.NetID, string(networkUser.ID), "coolnode") + assert.True(t, isUserNodeAllowed) + }) + + t.Run("Successful net user node not allowed", func(t *testing.T) { + isUserNodeAllowed := IsUserNodeAllowed(nodes[:], network.NetID, string(networkUser.ID), "notanode") + assert.False(t, isUserNodeAllowed) + }) + + t.Run("Successful net user client isallowed", func(t *testing.T) { + networkUser.Clients = append(networkUser.Clients, "coolclient") + err := UpdateNetworkUser(network.NetID, &networkUser) + assert.Nil(t, err) + isUserClientAllowed := IsUserClientAllowed(clients[:], network.NetID, string(networkUser.ID), "coolclient") + assert.True(t, isUserClientAllowed) + }) + + t.Run("Successful net user client not allowed", func(t *testing.T) { + isUserClientAllowed := IsUserClientAllowed(clients[:], network.NetID, string(networkUser.ID), "notaclient") + assert.False(t, isUserClientAllowed) + }) + + t.Run("Successful net user delete", func(t *testing.T) { + err := DeleteNetworkUser(network.NetID, string(networkUser.ID)) + assert.Nil(t, err) + user, err := GetNetworkUser(network.NetID, networkUser.ID) + assert.Nil(t, user) + assert.NotNil(t, err) + }) +} diff --git a/logic/pro/proacls/nodes.go b/logic/pro/proacls/nodes.go new file mode 100644 index 00000000..d55035e7 --- /dev/null +++ b/logic/pro/proacls/nodes.go @@ -0,0 +1,35 @@ +package proacls + +import ( + "github.com/gravitl/netmaker/logic/acls" + "github.com/gravitl/netmaker/logic/acls/nodeacls" + "github.com/gravitl/netmaker/models" +) + +// AdjustNodeAcls - adjusts ACLs based on a node's default value +func AdjustNodeAcls(node *models.Node, networkNodes []models.Node) error { + networkID := nodeacls.NetworkID(node.Network) + nodeID := nodeacls.NodeID(node.ID) + currentACLs, err := nodeacls.FetchAllACLs(networkID) + if err != nil { + return err + } + + for i := range networkNodes { + currentNodeID := nodeacls.NodeID(networkNodes[i].ID) + if currentNodeID == nodeID { + continue + } + // 2 cases + // both allow - allow + // either 1 denies - deny + if node.DoesACLAllow() { + currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(currentNodeID), acls.Allowed) + } else if node.DoesACLDeny() { + currentACLs.ChangeAccess(acls.AclID(nodeID), acls.AclID(currentNodeID), acls.NotAllowed) + } + } + + _, err = currentACLs.Save(acls.ContainerID(node.Network)) + return err +} diff --git a/logic/pro/types.go b/logic/pro/types.go new file mode 100644 index 00000000..d2063116 --- /dev/null +++ b/logic/pro/types.go @@ -0,0 +1,20 @@ +package pro + +const ( + // == NET ACCESS END == indicates access for system admin (control of netmaker) + // NET_ADMIN - indicates access for network admin (control of network) + NET_ADMIN = 0 + // NODE_ACCESS - indicates access for + NODE_ACCESS = 1 + // CLIENT_ACCESS - indicates access for network user (limited to nodes + ext clients) + CLIENT_ACCESS = 2 + // NO_ACCESS - indicates user has no access to network + NO_ACCESS = 3 + // == NET ACCESS END == + // DEFAULT_ALLOWED_GROUPS - default user group for all networks + DEFAULT_ALLOWED_GROUPS = "*" + // DEFAULT_ALLOWED_USERS - default allowed users for a network + DEFAULT_ALLOWED_USERS = "*" + // DB_GROUPS_KEY - represents db groups + DB_GROUPS_KEY = "netmaker-groups" +) diff --git a/logic/pro/usergroups.go b/logic/pro/usergroups.go new file mode 100644 index 00000000..e7132b3b --- /dev/null +++ b/logic/pro/usergroups.go @@ -0,0 +1,80 @@ +package pro + +import ( + "encoding/json" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models/promodels" +) + +// InitializeGroups - initialize groups data structure if not present in the DB +func InitializeGroups() error { + if !DoesUserGroupExist(DEFAULT_ALLOWED_GROUPS) { + return InsertUserGroup(DEFAULT_ALLOWED_GROUPS) + } + return nil +} + +// InsertUserGroup - inserts a group into the +func InsertUserGroup(groupName promodels.UserGroupName) error { + currentGroups, err := GetUserGroups() + if err != nil { + return err + } + currentGroups[groupName] = promodels.Void{} + newData, err := json.Marshal(¤tGroups) + if err != nil { + return err + } + return database.Insert(DB_GROUPS_KEY, string(newData), database.USER_GROUPS_TABLE_NAME) +} + +// DeleteUserGroup - deletes a group from database +func DeleteUserGroup(groupName promodels.UserGroupName) error { + var newGroups promodels.UserGroups + currentGroupRecords, err := database.FetchRecord(database.USER_GROUPS_TABLE_NAME, DB_GROUPS_KEY) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + if err = json.Unmarshal([]byte(currentGroupRecords), &newGroups); err != nil { + return err + } + delete(newGroups, groupName) + newData, err := json.Marshal(&newGroups) + if err != nil { + return err + } + return database.Insert(DB_GROUPS_KEY, string(newData), database.USER_GROUPS_TABLE_NAME) +} + +// GetUserGroups - get groups of users +func GetUserGroups() (promodels.UserGroups, error) { + var returnGroups promodels.UserGroups + groupsRecord, err := database.FetchRecord(database.USER_GROUPS_TABLE_NAME, DB_GROUPS_KEY) + if err != nil { + if database.IsEmptyRecord(err) { + return make(promodels.UserGroups, 1), nil + } + return returnGroups, err + } + + if err = json.Unmarshal([]byte(groupsRecord), &returnGroups); err != nil { + return returnGroups, err + } + + return returnGroups, nil +} + +// DoesUserGroupExist - checks if a user group exists +func DoesUserGroupExist(group promodels.UserGroupName) bool { + currentGroups, err := GetUserGroups() + if err != nil { + return true + } + for k := range currentGroups { + if k == group { + return true + } + } + return false +} diff --git a/logic/pro/usergroups_test.go b/logic/pro/usergroups_test.go new file mode 100644 index 00000000..cd472e25 --- /dev/null +++ b/logic/pro/usergroups_test.go @@ -0,0 +1,43 @@ +package pro + +import ( + "testing" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models/promodels" + "github.com/stretchr/testify/assert" +) + +func TestUserGroupLogic(t *testing.T) { + database.InitializeDatabase() + + t.Run("User Groups initialized successfully", func(t *testing.T) { + err := InitializeGroups() + assert.Nil(t, err) + }) + + t.Run("Check for default group", func(t *testing.T) { + groups, err := GetUserGroups() + assert.Nil(t, err) + var hasdefault bool + for k := range groups { + if string(k) == DEFAULT_ALLOWED_GROUPS { + hasdefault = true + } + } + assert.True(t, hasdefault) + }) + + t.Run("User Groups created successfully", func(t *testing.T) { + err := InsertUserGroup(promodels.UserGroupName("group1")) + assert.Nil(t, err) + err = InsertUserGroup(promodels.UserGroupName("group2")) + assert.Nil(t, err) + }) + + t.Run("User Groups deleted successfully", func(t *testing.T) { + err := DeleteUserGroup(promodels.UserGroupName("group1")) + assert.Nil(t, err) + assert.False(t, DoesUserGroupExist(promodels.UserGroupName("group1"))) + }) +} diff --git a/logic/users.go b/logic/users.go index 8dd7e2a1..f920905f 100644 --- a/logic/users.go +++ b/logic/users.go @@ -4,7 +4,10 @@ import ( "encoding/json" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" ) // GetUser - gets a user @@ -20,3 +23,50 @@ func GetUser(username string) (models.User, error) { } return user, err } + +// GetGroupUsers - gets users in a group +func GetGroupUsers(group string) ([]models.ReturnUser, error) { + var returnUsers []models.ReturnUser + users, err := GetUsers() + if err != nil { + return returnUsers, err + } + for _, user := range users { + if StringSliceContains(user.Groups, group) { + users = append(users, user) + } + } + return users, err +} + +// == PRO == + +// InitializeNetUsers - intializes network users for all users/networks +func InitializeNetUsers(network *models.Network) error { + // == add all current users to network as network users == + currentUsers, err := GetUsers() + if err != nil { + return err + } + + for i := range currentUsers { // add all users to given network + newUser := promodels.NetworkUser{ + ID: promodels.NetworkUserID(currentUsers[i].UserName), + Clients: []string{}, + Nodes: []string{}, + AccessLevel: pro.NO_ACCESS, + ClientLimit: 0, + NodeLimit: 0, + } + if pro.IsUserAllowed(network, currentUsers[i].UserName, currentUsers[i].Groups) { + newUser.AccessLevel = network.ProSettings.DefaultAccessLevel + newUser.ClientLimit = network.ProSettings.DefaultUserClientLimit + newUser.NodeLimit = network.ProSettings.DefaultUserNodeLimit + } + + if err = pro.CreateNetworkUser(network, &newUser); err != nil { + logger.Log(0, "failed to add network user settings to user", string(newUser.ID), "on network", network.NetID) + } + } + return nil +} diff --git a/logic/util.go b/logic/util.go index c22ce818..6fe25af8 100644 --- a/logic/util.go +++ b/logic/util.go @@ -203,3 +203,18 @@ func getNetworkProtocols(cidrs []string) (bool, bool) { } return ipv4, ipv6 } + +// StringDifference - returns the elements in `a` that aren't in `b`. +func StringDifference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +} diff --git a/main.go b/main.go index e0fea725..5c50e0b0 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// -build ee package main import ( @@ -19,9 +20,11 @@ import ( "github.com/gravitl/netmaker/config" controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/netclient/ncutils" @@ -36,11 +39,11 @@ var version = "dev" func main() { absoluteConfigPath := flag.String("c", "", "absolute path to configuration file") flag.Parse() - setupConfig(*absoluteConfigPath) servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo - initialize() // initial db and acls; gen cert if required + // fmt.Println(models.ProLogo()) + initialize() // initial db and acls; gen cert if required setGarbageCollection() setVerbosity() defer database.CloseDB() @@ -73,14 +76,34 @@ func initialize() { // Client Mode Prereq Check logger.FatalLog("Error connecting to database") } logger.Log(0, "database successfully connected") - logic.SetJWTSecret() - if err = logic.AddServerIDIfNotPresent(); err != nil { + if err = ee.AddServerIDIfNotPresent(); err != nil { logger.Log(1, "failed to save server ID") } + + logic.SetJWTSecret() + + if err = pro.InitializeGroups(); err != nil { + logger.Log(0, "could not initialize default user group, \"*\"") + } + err = logic.TimerCheckpoint() if err != nil { logger.Log(1, "Timer error occurred: ", err.Error()) } + + if ee.IsEnterprise() { + // == License Handling == + ee.ValidateLicense() + if ee.Limits.FreeTier { + logger.Log(0, "proceeding with Free Tier license") + } else { + logger.Log(0, "proceeding with Paid Tier license") + } + // == End License Handling == + + ee.AddLicenseHooks() + } + var authProvider = auth.InitializeAuthProvider() if authProvider != "" { logger.Log(0, "OAuth provider,", authProvider+",", "initialized") diff --git a/main_ee.go b/main_ee.go new file mode 100644 index 00000000..ba40a39a --- /dev/null +++ b/main_ee.go @@ -0,0 +1,30 @@ +//go:build ee +// +build ee + +package main + +import ( + "github.com/gravitl/netmaker/ee" + "github.com/gravitl/netmaker/models" +) + +func init() { + ee.SetIsEnterprise() + models.SetLogo(retrieveEELogo()) +} + +func retrieveEELogo() string { + return ` + __ __ ______ ______ __ __ ______ __ __ ______ ______ +/\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \ +\ \ \-. \ \ \ __\ \/_/\ \/ \ \ \-./\ \ \ \ __ \ \ \ _"-. \ \ __\ \ \ __< + \ \_\\"\_\ \ \_____\ \ \_\ \ \_\ \ \_\ \ \_\ \_\ \ \_\ \_\ \ \_____\ \ \_\ \_\ + \/_/ \/_/ \/_____/ \/_/ \/_/ \/_/ \/_/\/_/ \/_/\/_/ \/_____/ \/_/ /_/ + + ___ ___ ____ + ____ ____ ____ / _ \ / _ \ / __ \ ____ ____ ____ + /___/ /___/ /___/ / ___/ / , _// /_/ / /___/ /___/ /___/ + /___/ /___/ /___/ /_/ /_/|_| \____/ /___/ /___/ /___/ + +` +} diff --git a/models/extclient.go b/models/extclient.go index 84593f5f..c984a538 100644 --- a/models/extclient.go +++ b/models/extclient.go @@ -13,4 +13,5 @@ type ExtClient struct { IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` LastModified int64 `json:"lastmodified" bson:"lastmodified"` Enabled bool `json:"enabled" bson:"enabled"` + OwnerID string `json:"ownerid" bson:"ownerid"` } diff --git a/models/metrics.go b/models/metrics.go new file mode 100644 index 00000000..cb8e2b74 --- /dev/null +++ b/models/metrics.go @@ -0,0 +1,45 @@ +package models + +import "time" + +// Metrics - metrics struct +type Metrics struct { + Network string `json:"network" bson:"network" yaml:"network"` + NodeID string `json:"node_id" bson:"node_id" yaml:"node_id"` + NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` + IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` + Connectivity map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"` +} + +// Metric - holds a metric for data between nodes +type Metric struct { + NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` + IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` + Uptime int64 `json:"uptime" bson:"uptime" yaml:"uptime"` + TotalTime int64 `json:"totaltime" bson:"totaltime" yaml:"totaltime"` + Latency int64 `json:"latency" bson:"latency" yaml:"latency"` + TotalReceived int64 `json:"totalreceived" bson:"totalreceived" yaml:"totalreceived"` + TotalSent int64 `json:"totalsent" bson:"totalsent" yaml:"totalsent"` + ActualUptime time.Duration `json:"actualuptime" bson:"actualuptime" yaml:"actualuptime"` + PercentUp float64 `json:"percentup" bson:"percentup" yaml:"percentup"` + Connected bool `json:"connected" bson:"connected" yaml:"connected"` +} + +// IDandAddr - struct to hold ID and primary Address +type IDandAddr struct { + ID string `json:"id" bson:"id" yaml:"id"` + Address string `json:"address" bson:"address" yaml:"address"` + Name string `json:"name" bson:"name" yaml:"name"` + IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` +} + +// PeerMap - peer map for ids and addresses in metrics +type PeerMap map[string]IDandAddr + +// MetricsMap - map for holding multiple metrics in memory +type MetricsMap map[string]Metrics + +// NetworkMetrics - metrics model for all nodes in a network +type NetworkMetrics struct { + Nodes MetricsMap `json:"nodes" bson:"nodes" yaml:"nodes"` +} diff --git a/models/mqtt.go b/models/mqtt.go index fbb4ee7b..4b52e956 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -9,6 +9,7 @@ type PeerUpdate struct { ServerAddrs []ServerAddr `json:"serveraddrs" bson:"serveraddrs" yaml:"serveraddrs"` Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` DNS string `json:"dns" bson:"dns" yaml:"dns"` + PeerIDs PeerMap `json:"peerids" bson:"peerids" yaml:"peerids"` } // KeyUpdate - key update struct diff --git a/models/names.go b/models/names.go index faa6f8cd..f5c1301f 100644 --- a/models/names.go +++ b/models/names.go @@ -231,6 +231,8 @@ var SMALL_NAMES = []string{ "cold", } +var logoString = retrieveLogo() + // GenerateNodeName - generates a random node name func GenerateNodeName() string { rand.Seed(time.Now().UnixNano()) @@ -239,6 +241,15 @@ func GenerateNodeName() string { // RetrieveLogo - retrieves the ascii art logo for Netmaker func RetrieveLogo() string { + return logoString +} + +// SetLogo - sets the logo ascii art +func SetLogo(logo string) { + logoString = logo +} + +func retrieveLogo() string { return ` __ __ ______ ______ __ __ ______ __ __ ______ ______ /\ "-.\ \ /\ ___\ /\__ _\ /\ "-./ \ /\ __ \ /\ \/ / /\ ___\ /\ == \ diff --git a/models/network.go b/models/network.go index 057a1a56..3442e84d 100644 --- a/models/network.go +++ b/models/network.go @@ -2,33 +2,36 @@ package models import ( "time" + + "github.com/gravitl/netmaker/models/promodels" ) // Network Struct - contains info for a given unique network //At some point, need to replace all instances of Name with something else like Identifier type Network struct { - AddressRange string `json:"addressrange" bson:"addressrange" validate:"omitempty,cidrv4"` - AddressRange6 string `json:"addressrange6" bson:"addressrange6" validate:"omitempty,cidrv6"` - NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=12,netid_valid"` - NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified"` - NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified"` - DefaultInterface string `json:"defaultinterface" bson:"defaultinterface" validate:"min=1,max=15"` - DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,min=1024,max=65535"` - NodeLimit int32 `json:"nodelimit" bson:"nodelimit"` - DefaultPostUp string `json:"defaultpostup" bson:"defaultpostup"` - DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"` - DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate:"omitempty,max=1000"` - AccessKeys []AccessKey `json:"accesskeys" bson:"accesskeys"` - AllowManualSignUp string `json:"allowmanualsignup" bson:"allowmanualsignup" validate:"checkyesorno"` - IsLocal string `json:"islocal" bson:"islocal" validate:"checkyesorno"` - IsIPv4 string `json:"isipv4" bson:"isipv4" validate:"checkyesorno"` - IsIPv6 string `json:"isipv6" bson:"isipv6" validate:"checkyesorno"` - IsPointToSite string `json:"ispointtosite" bson:"ispointtosite" validate:"checkyesorno"` - LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` - DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` - DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` - DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` - DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"` + AddressRange string `json:"addressrange" bson:"addressrange" validate:"omitempty,cidrv4"` + AddressRange6 string `json:"addressrange6" bson:"addressrange6" validate:"omitempty,cidrv6"` + NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=12,netid_valid"` + NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified"` + NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified"` + DefaultInterface string `json:"defaultinterface" bson:"defaultinterface" validate:"min=1,max=15"` + DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,min=1024,max=65535"` + NodeLimit int32 `json:"nodelimit" bson:"nodelimit"` + DefaultPostUp string `json:"defaultpostup" bson:"defaultpostup"` + DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"` + DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate:"omitempty,max=1000"` + AccessKeys []AccessKey `json:"accesskeys" bson:"accesskeys"` + AllowManualSignUp string `json:"allowmanualsignup" bson:"allowmanualsignup" validate:"checkyesorno"` + IsLocal string `json:"islocal" bson:"islocal" validate:"checkyesorno"` + IsIPv4 string `json:"isipv4" bson:"isipv4" validate:"checkyesorno"` + IsIPv6 string `json:"isipv6" bson:"isipv6" validate:"checkyesorno"` + IsPointToSite string `json:"ispointtosite" bson:"ispointtosite" validate:"checkyesorno"` + LocalRange string `json:"localrange" bson:"localrange" validate:"omitempty,cidr"` + DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"` + DefaultExtClientDNS string `json:"defaultextclientdns" bson:"defaultextclientdns"` + DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"` + DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"` + ProSettings *promodels.ProNetwork `json:"prosettings,omitempty" bson:"prosettings,omitempty" yaml:"prosettings,omitempty"` } // SaveData - sensitive fields of a network that should be kept the same diff --git a/models/node.go b/models/node.go index c4a477ca..0aebf0a0 100644 --- a/models/node.go +++ b/models/node.go @@ -101,6 +101,9 @@ type Node struct { FirewallInUse string `json:"firewallinuse" bson:"firewallinuse" yaml:"firewallinuse"` InternetGateway string `json:"internetgateway" bson:"internetgateway" yaml:"internetgateway"` Connected string `json:"connected" bson:"connected" yaml:"connected" validate:"checkyesorno"` + // == PRO == + DefaultACL string `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"` + OwnerID string `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"` } // NodesArray - used for node sorting @@ -438,6 +441,10 @@ func (newNode *Node) Fill(currentNode *Node) { // TODO add new field for nftable if newNode.Connected == "" { newNode.Connected = currentNode.Connected } + if newNode.DefaultACL == "" { + newNode.DefaultACL = currentNode.DefaultACL + } + newNode.TrafficKeys = currentNode.TrafficKeys } @@ -469,3 +476,15 @@ func (node *Node) NameInNodeCharSet() bool { } return true } + +// == PRO == + +// Node.DoesACLAllow - checks if default ACL on node is "yes" +func (node *Node) DoesACLAllow() bool { + return node.DefaultACL == "yes" +} + +// Node.DoesACLDeny - checks if default ACL on node is "no" +func (node *Node) DoesACLDeny() bool { + return node.DefaultACL == "no" +} diff --git a/models/promodels/networkuser.go b/models/promodels/networkuser.go new file mode 100644 index 00000000..a6865335 --- /dev/null +++ b/models/promodels/networkuser.go @@ -0,0 +1,27 @@ +package promodels + +// NetworkUserID - ID field for a network user +type NetworkUserID string + +// NetworkUser - holds fields for a network user +type NetworkUser struct { + AccessLevel int `json:"accesslevel" bson:"accesslevel" yaml:"accesslevel"` + ClientLimit int `json:"clientlimit" bson:"clientlimit" yaml:"clientlimit"` + NodeLimit int `json:"nodelimit" bson:"nodelimit" yaml:"nodelimit"` + ID NetworkUserID `json:"id" bson:"id" yaml:"id"` + Clients []string `json:"clients" bson:"clients" yaml:"clients"` + Nodes []string `json:"nodes" bson:"nodes" yaml:"nodes"` +} + +// NetworkUserMap - map of network users +type NetworkUserMap map[NetworkUserID]NetworkUser + +// NetworkUserMap.Delete - deletes a network user struct from a given map in memory +func (N NetworkUserMap) Delete(ID NetworkUserID) { + delete(N, ID) +} + +// NetworkUserMap.Add - adds a network user struct to given network user map in memory +func (N NetworkUserMap) Add(User *NetworkUser) { + N[User.ID] = *User +} diff --git a/models/promodels/pro.go b/models/promodels/pro.go new file mode 100644 index 00000000..47389aec --- /dev/null +++ b/models/promodels/pro.go @@ -0,0 +1,19 @@ +package promodels + +// ProNetwork - struct for all pro Network related fields +type ProNetwork struct { + DefaultAccessLevel int `json:"defaultaccesslevel" bson:"defaultaccesslevel" yaml:"defaultaccesslevel"` + DefaultUserNodeLimit int `json:"defaultusernodelimit" bson:"defaultusernodelimit" yaml:"defaultusernodelimit"` + DefaultUserClientLimit int `json:"defaultuserclientlimit" bson:"defaultuserclientlimit" yaml:"defaultuserclientlimit"` + AllowedUsers []string `json:"allowedusers" bson:"allowedusers" yaml:"allowedusers"` + AllowedGroups []string `json:"allowedgroups" bson:"allowedgroups" yaml:"allowedgroups"` +} + +// LoginMsg - login message struct for nodes to join via SSO login +// Need to change mac to public key for tighter verification ? +type LoginMsg struct { + Mac string `json:"mac"` + Network string `json:"network"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` +} diff --git a/models/promodels/usergroups.go b/models/promodels/usergroups.go new file mode 100644 index 00000000..e01e6e9c --- /dev/null +++ b/models/promodels/usergroups.go @@ -0,0 +1,9 @@ +package promodels + +type Void struct{} + +// UserGroupName - string representing a group name +type UserGroupName string + +// UserGroups - groups type, holds group names +type UserGroups map[UserGroupName]Void diff --git a/models/structs.go b/models/structs.go index 79f72a93..6b80e4f9 100644 --- a/models/structs.go +++ b/models/structs.go @@ -2,6 +2,7 @@ package models import ( "strings" + "time" jwt "github.com/golang-jwt/jwt/v4" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -28,6 +29,7 @@ type User struct { Password string `json:"password" bson:"password" validate:"required,min=5"` Networks []string `json:"networks" bson:"networks"` IsAdmin bool `json:"isadmin" bson:"isadmin"` + Groups []string `json:"groups" bson:"groups" yaml:"groups"` } // ReturnUser - return user struct @@ -35,6 +37,7 @@ type ReturnUser struct { UserName string `json:"username" bson:"username"` Networks []string `json:"networks" bson:"networks"` IsAdmin bool `json:"isadmin" bson:"isadmin"` + Groups []string `json:"groups" bson:"groups"` } // UserAuthParams - user auth params struct @@ -48,6 +51,7 @@ type UserClaims struct { IsAdmin bool UserName string Networks []string + Groups []string jwt.RegisteredClaims } @@ -95,10 +99,11 @@ type SuccessResponse struct { // AccessKey - access key struct type AccessKey struct { - Name string `json:"name" bson:"name" validate:"omitempty,max=20"` - Value string `json:"value" bson:"value" validate:"omitempty,alphanum,max=16"` - AccessString string `json:"accessstring" bson:"accessstring"` - Uses int `json:"uses" bson:"uses" validate:"numeric,min=0"` + Name string `json:"name" bson:"name" validate:"omitempty,max=345"` + Value string `json:"value" bson:"value" validate:"omitempty,alphanum,max=16"` + AccessString string `json:"accessstring" bson:"accessstring"` + Uses int `json:"uses" bson:"uses" validate:"numeric,min=0"` + Expiration *time.Time `json:"expiration" bson:"expiration"` } // DisplayKey - what is displayed for key @@ -200,6 +205,7 @@ type NodeGet struct { Node Node `json:"node" bson:"node" yaml:"node"` Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"` ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"` + PeerIDs PeerMap `json:"peerids,omitempty" bson:"peerids,omitempty" yaml:"peerids,omitempty"` } // ServerConfig - struct for dealing with the server information for a netclient diff --git a/mq/handlers.go b/mq/handlers.go index e00646b7..99695eb7 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -2,13 +2,17 @@ package mq import ( "encoding/json" + "fmt" + "time" mqtt "github.com/eclipse/paho.mqtt.golang" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/ee" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" + "github.com/gravitl/netmaker/servercfg" ) // DefaultHandler default message queue handler -- NOT USED @@ -93,6 +97,50 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) { }() } +// UpdateMetrics message Handler -- handles updates from client nodes for metrics +func UpdateMetrics(client mqtt.Client, msg mqtt.Message) { + if ee.IsEnterprise() { + go func() { + id, err := getID(msg.Topic()) + if err != nil { + logger.Log(1, "error getting node.ID sent on ", msg.Topic(), err.Error()) + return + } + currentNode, err := logic.GetNodeByID(id) + if err != nil { + logger.Log(1, "error getting node ", id, err.Error()) + return + } + decrypted, decryptErr := decryptMsg(¤tNode, msg.Payload()) + if decryptErr != nil { + logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error()) + return + } + + var newMetrics models.Metrics + if err := json.Unmarshal(decrypted, &newMetrics); err != nil { + logger.Log(1, "error unmarshaling payload ", err.Error()) + return + } + + updateNodeMetrics(¤tNode, &newMetrics) + + if err = logic.UpdateMetrics(id, &newMetrics); err != nil { + logger.Log(1, "faield to update node metrics", id, currentNode.Name, err.Error()) + return + } + if servercfg.IsMetricsExporter() { + if err := pushMetricsToExporter(newMetrics); err != nil { + logger.Log(2, fmt.Sprintf("failed to push node: [%s] metrics to exporter, err: %v", + currentNode.Name, err)) + } + } + + logger.Log(1, "updated node metrics", id, currentNode.Name) + }() + } +} + // ClientPeerUpdate message handler -- handles updating peers after signal from client nodes func ClientPeerUpdate(client mqtt.Client, msg mqtt.Message) { go func() { @@ -146,3 +194,46 @@ func updateNodePeers(currentNode *models.Node) { } } } + +func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) { + oldMetrics, err := logic.GetMetrics(currentNode.ID) + if err != nil { + logger.Log(1, "error finding old metrics for node", currentNode.ID, currentNode.Name) + return + } + + var attachedClients []models.ExtClient + if currentNode.IsIngressGateway == "yes" { + clients, err := logic.GetExtClientsByID(currentNode.ID, currentNode.Network) + if err == nil { + attachedClients = clients + } + } + if len(attachedClients) > 0 { + // associate ext clients with IDs + for i := range attachedClients { + extMetric := newMetrics.Connectivity[attachedClients[i].PublicKey] + delete(newMetrics.Connectivity, attachedClients[i].PublicKey) + if extMetric.Connected { // add ext client metrics + newMetrics.Connectivity[attachedClients[i].ClientID] = extMetric + } + } + } + + // run through metrics for each peer + for k := range newMetrics.Connectivity { + currMetric := newMetrics.Connectivity[k] + oldMetric := oldMetrics.Connectivity[k] + currMetric.TotalTime += oldMetric.TotalTime + currMetric.Uptime += oldMetric.Uptime // get the total uptime for this connection + currMetric.PercentUp = 100.0 * (float64(currMetric.Uptime) / float64(currMetric.TotalTime)) + totalUpMinutes := currMetric.Uptime * 5 + currMetric.ActualUptime = time.Duration(totalUpMinutes) * time.Minute + delete(oldMetrics.Connectivity, k) // remove from old data + newMetrics.Connectivity[k] = currMetric + } + + for k := range oldMetrics.Connectivity { // cleanup any left over data, self healing + delete(newMetrics.Connectivity, k) + } +} diff --git a/mq/mq.go b/mq/mq.go index edb48539..8453cb0f 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -51,6 +51,10 @@ func SetupMQTT() { client.Disconnect(240) logger.Log(0, "node client subscription failed") } + if token := client.Subscribe("metrics/#", 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil { + client.Disconnect(240) + logger.Log(0, "node metrics subscription failed") + } opts.SetOrderMatters(true) opts.SetResumeSubs(true) diff --git a/mq/publishers.go b/mq/publishers.go index 423e1461..abdc5167 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -2,11 +2,13 @@ package mq import ( "encoding/json" + "errors" "fmt" + "time" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/logic/pro/metrics" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/serverctl" @@ -105,6 +107,11 @@ func NodeUpdate(node *models.Node) error { // sendPeers - retrieve networks, send peer ports to all peers func sendPeers() { + networks, err := logic.GetNetworks() + if err != nil { + logger.Log(1, "error retrieving networks for keepalive", err.Error()) + } + var force bool peer_force_send++ if peer_force_send == 5 { @@ -121,10 +128,8 @@ func sendPeers() { if err != nil { logger.Log(3, "error occurred on timer,", err.Error()) } - } - networks, err := logic.GetNetworks() - if err != nil && !database.IsEmptyRecord(err) { - logger.Log(1, "error retrieving networks for keepalive", err.Error()) + + collectServerMetrics(networks[:]) } for _, network := range networks { @@ -176,3 +181,64 @@ func ServerStartNotify() error { } return nil } + +// function to collect and store metrics for server nodes +func collectServerMetrics(networks []models.Network) { + if len(networks) > 0 { + for i := range networks { + currentNetworkNodes, err := logic.GetNetworkNodes(networks[i].NetID) + if err != nil { + continue + } + currentServerNodes := logic.GetServerNodes(networks[i].NetID) + if len(currentServerNodes) > 0 { + for i := range currentServerNodes { + if logic.IsLocalServer(¤tServerNodes[i]) { + serverMetrics := logic.CollectServerMetrics(currentServerNodes[i].ID, currentNetworkNodes) + if serverMetrics != nil { + serverMetrics.NodeName = currentServerNodes[i].Name + serverMetrics.NodeID = currentServerNodes[i].ID + serverMetrics.IsServer = "yes" + serverMetrics.Network = currentServerNodes[i].Network + if err = metrics.GetExchangedBytesForNode(¤tServerNodes[i], serverMetrics); err != nil { + logger.Log(1, fmt.Sprintf("failed to update exchanged bytes info for server: %s, err: %v", + currentServerNodes[i].Name, err)) + } + updateNodeMetrics(¤tServerNodes[i], serverMetrics) + if err = logic.UpdateMetrics(currentServerNodes[i].ID, serverMetrics); err != nil { + logger.Log(1, "failed to update metrics for server node", currentServerNodes[i].ID) + } + if servercfg.IsMetricsExporter() { + logger.Log(2, "-------------> SERVER METRICS: ", fmt.Sprintf("%+v", serverMetrics)) + if err := pushMetricsToExporter(*serverMetrics); err != nil { + logger.Log(2, "failed to push server metrics to exporter: ", err.Error()) + } + } + + } + + } + } + } + } + } +} + +func pushMetricsToExporter(metrics models.Metrics) error { + logger.Log(2, "----> Pushing metrics to exporter") + SetupMQTT() + data, err := json.Marshal(metrics) + if err != nil { + return errors.New("failed to marshal metrics: " + err.Error()) + } + if token := mqclient.Publish("metrics_exporter", 0, true, data); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil { + var err error + if token.Error() == nil { + err = errors.New("connection timeout") + } else { + err = token.Error() + } + return err + } + return nil +} diff --git a/netclient/cli_options/flags.go b/netclient/cli_options/flags.go index 50ec8725..2887306d 100644 --- a/netclient/cli_options/flags.go +++ b/netclient/cli_options/flags.go @@ -133,6 +133,20 @@ func GetFlags(hostname string) []cli.Flag { Value: "", Usage: "Access Token for signing up machine with Netmaker server during initial 'add'.", }, + &cli.StringFlag{ + Name: "login-server", + Aliases: []string{"l"}, + EnvVars: []string{"LOGIN_SERVER"}, + Value: "", + Usage: "Login server URL, use it for the Single Sign-on along with the network parameter", + }, + &cli.StringFlag{ + Name: "user", + Aliases: []string{"u"}, + EnvVars: []string{"USER_NAME"}, + Value: "", + Usage: "User name provided upon joins if joining over basic auth is desired.", + }, &cli.StringFlag{ Name: "localrange", EnvVars: []string{"NETCLIENT_LOCALRANGE"}, diff --git a/netclient/command/commands.go b/netclient/command/commands.go index dc1195d5..cd163dc3 100644 --- a/netclient/command/commands.go +++ b/netclient/command/commands.go @@ -3,6 +3,7 @@ package command import ( "crypto/ed25519" "crypto/rand" + "errors" "fmt" "strings" @@ -18,6 +19,25 @@ import ( func Join(cfg *config.ClientConfig, privateKey string) error { var err error //join network + if cfg.SsoServer != "" { + // User wants to get access key from the OIDC server + // Do that before the Joining Network flow by performing the end point auth flow + // if performed successfully an access key is obtained from the server and then we + // proceed with the usual flow 'pretending' that user is feeded us with an access token + logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer) + err = functions.JoinViaSSo(cfg, privateKey) + if err != nil { + logger.Log(0, "Join via OIDC failed: ", err.Error()) + return err + } + + if cfg.AccessKey == "" { + return errors.New("failed to get access key") + } + logger.Log(1, "Got an access key to ", cfg.Network, " via:", cfg.SsoServer) + } + + logger.Log(1, "Joining network: ", cfg.Network) err = functions.JoinNetwork(cfg, privateKey) if err != nil { if !strings.Contains(err.Error(), "ALREADY_INSTALLED") { diff --git a/netclient/config/config.go b/netclient/config/config.go index 5a4dd6cb..5a8276ff 100644 --- a/netclient/config/config.go +++ b/netclient/config/config.go @@ -32,6 +32,7 @@ type ClientConfig struct { OperatingSystem string `yaml:"operatingsystem"` AccessKey string `yaml:"accesskey"` PublicIPService string `yaml:"publicipservice"` + SsoServer string `yaml:"sso"` } // RegisterRequest - struct for registation with netmaker server @@ -239,6 +240,11 @@ func GetCLIConfig(c *cli.Context) (ClientConfig, string, error) { if c.String("apiserver") != "" { cfg.Server.API = c.String("apiserver") } + } else if c.String("login-server") != "" { + cfg.SsoServer = c.String("login-server") + cfg.Network = c.String("network") + cfg.Node.Network = c.String("network") + global_settings.User = c.String("user") } else { cfg.AccessKey = c.String("key") cfg.Network = c.String("network") diff --git a/netclient/functions/join.go b/netclient/functions/join.go index 1b0eca60..fc53e295 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -8,21 +8,174 @@ import ( "io" "log" "net/http" + "os" + "os/signal" "runtime" + "strings" + "syscall" + "time" + "github.com/gorilla/websocket" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/models/promodels" "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/daemon" + "github.com/gravitl/netmaker/netclient/global_settings" "github.com/gravitl/netmaker/netclient/local" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/wireguard" "golang.org/x/crypto/nacl/box" + "golang.org/x/term" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// JoinViaSso - Handles the Single Sign-On flow on the end point VPN client side +// Contacts the server provided by the user (and thus specified in cfg.SsoServer) +// get the URL to authenticate with a provider and shows the user the URL. +// Then waits for user to authenticate with the URL. +// Upon user successful auth flow finished - server should return access token to the requested network +// Otherwise the error message is sent which can be displayed to the user +func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { + + // User must tell us which network he is joining + if cfg.Node.Network == "" { + return errors.New("no network provided") + } + + // Prepare a channel for interrupt + // Channel to listen for interrupt signal to terminate gracefully + interrupt := make(chan os.Signal, 1) + // Notify the interrupt channel for SIGINT + signal.Notify(interrupt, os.Interrupt) + + // Web Socket is used, construct the URL accordingly ... + socketUrl := fmt.Sprintf("wss://%s/api/oauth/node-handler", cfg.SsoServer) + // Dial the netmaker server controller + conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil) + if err != nil { + logger.Log(0, fmt.Sprintf("Error connecting to %s : %s", cfg.Server.API, err.Error())) + return err + } + // Don't forget to close when finished + defer conn.Close() + // Find and set node MacAddress + if cfg.Node.MacAddress == "" { + macs, err := ncutils.GetMacAddr() + if err != nil { + //if macaddress can't be found set to random string + cfg.Node.MacAddress = ncutils.MakeRandomString(18) + } else { + cfg.Node.MacAddress = macs[0] + } + } + + var loginMsg promodels.LoginMsg + loginMsg.Mac = cfg.Node.MacAddress + loginMsg.Network = cfg.Node.Network + if global_settings.User != "" { + fmt.Printf("Continuing with user, %s.\nPlease input password:\n", global_settings.User) + pass, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil || string(pass) == "" { + logger.FatalLog("no password provided, exiting") + } + loginMsg.User = global_settings.User + loginMsg.Password = string(pass) + } + + msgTx, err := json.Marshal(loginMsg) + if err != nil { + logger.Log(0, fmt.Sprintf("failed to marshal message %+v", loginMsg)) + return err + } + err = conn.WriteMessage(websocket.TextMessage, []byte(msgTx)) + if err != nil { + logger.FatalLog("Error during writing to websocket:", err.Error()) + return err + } + + // if user provided, server will handle authentication + if loginMsg.User == "" { + // We are going to get instructions on how to authenticate + // Wait to receive something from server + _, msg, err := conn.ReadMessage() + if err != nil { + log.Println("Error in receive:", err) + return err + } + // Print message from the netmaker controller to the user + fmt.Printf("Please visit:\n %s \n to authenticate", string(msg)) + } + + // Now the user is authenticating and we need to block until received + // An answer from the server. + // Server waits ~5 min - If takes too long timeout will be triggered by the server + done := make(chan struct{}) + // Following code will run in a separate go routine + // it reads a message from the server which either contains 'AccessToken:' string or not + // if not - then it contains an Error to display. + // if yes - then AccessToken is to be used to proceed joining the network + go func() { + defer close(done) + for { + _, msg, err := conn.ReadMessage() + if err != nil { + // Error reading a message from the server + if !strings.Contains(err.Error(), "normal") { + logger.Log(0, "read:", err.Error()) + } + return + } + // Get the access token from the response + if strings.Contains(string(msg), "AccessToken: ") { + // Access was granted + rxToken := strings.TrimPrefix(string(msg), "AccessToken: ") + accesstoken, err := config.ParseAccessToken(rxToken) + if err != nil { + log.Printf("Failed to parse received access token %s,err=%s\n", accesstoken, err.Error()) + return + } + + cfg.Network = accesstoken.ClientConfig.Network + cfg.Node.Network = accesstoken.ClientConfig.Network + cfg.AccessKey = accesstoken.ClientConfig.Key + cfg.Node.LocalRange = accesstoken.ClientConfig.LocalRange + //cfg.Server.Server = accesstoken.ServerConfig.Server + cfg.Server.API = accesstoken.APIConnString + } else { + // Access was not granted. Display a message from the server + logger.Log(0, "Message from server:", string(msg)) + cfg.AccessKey = "" + return + } + } + }() + + for { + select { + case <-done: + logger.Log(1, "finished") + return nil + case <-interrupt: + log.Println("interrupt") + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + logger.Log(0, "write close:", err.Error()) + return err + } + select { + case <-done: + case <-time.After(time.Second): + } + return nil + } + } +} + // JoinNetwork - helps a client join a network func JoinNetwork(cfg *config.ClientConfig, privateKey string) error { if cfg.Node.Network == "" { diff --git a/netclient/functions/mqpublish.go b/netclient/functions/mqpublish.go index 8ee1cf6e..5f1fdae6 100644 --- a/netclient/functions/mqpublish.go +++ b/netclient/functions/mqpublish.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" + "net/http" "os" "strconv" "sync" @@ -13,6 +15,7 @@ import ( "github.com/cloverstd/tcping/ping" "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic/pro/metrics" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/auth" "github.com/gravitl/netmaker/netclient/config" @@ -20,13 +23,16 @@ import ( "github.com/gravitl/netmaker/tls" ) +var metricsCache = new(sync.Map) + // Checkin -- go routine that checks for public or local ip changes, publishes changes // // if there are no updates, simply "pings" the server as a checkin func Checkin(ctx context.Context, wg *sync.WaitGroup) { logger.Log(2, "starting checkin goroutine") defer wg.Done() - checkin() + currentRun := 0 + checkin(currentRun) ticker := time.NewTicker(time.Second * 60) defer ticker.Stop() for { @@ -36,12 +42,16 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup) { return //delay should be configuraable -> use cfg.Node.NetworkSettings.DefaultCheckInInterval ?? case <-ticker.C: - checkin() + currentRun++ + checkin(currentRun) + if currentRun >= 5 { + currentRun = 0 + } } } } -func checkin() { +func checkin(currentRun int) { networks, _ := ncutils.GetSystemNetworks() logger.Log(3, "checkin with server(s) for all networks") for _, network := range networks { @@ -104,6 +114,10 @@ func checkin() { } Hello(&nodeCfg) checkCertExpiry(&nodeCfg) + if currentRun >= 5 { + logger.Log(0, "collecting metrics for node", nodeCfg.Node.Name) + publishMetrics(&nodeCfg) + } } } @@ -146,6 +160,78 @@ func Hello(nodeCfg *config.ClientConfig) { } } +// publishMetrics - publishes the metrics of a given nodecfg +func publishMetrics(nodeCfg *config.ClientConfig) { + token, err := Authenticate(nodeCfg) + if err != nil { + logger.Log(1, "failed to authenticate when publishing metrics", err.Error()) + return + } + url := "https://" + nodeCfg.Server.API + "/api/nodes/" + nodeCfg.Network + "/" + nodeCfg.Node.ID + response, err := API("", http.MethodGet, url, token) + if err != nil { + logger.Log(1, "failed to read from server during metrics publish", err.Error()) + return + } + if response.StatusCode != http.StatusOK { + bytes, err := io.ReadAll(response.Body) + if err != nil { + fmt.Println(err) + } + logger.Log(0, fmt.Sprintf("%s %s", string(bytes), err.Error())) + return + } + defer response.Body.Close() + var nodeGET models.NodeGet + if err := json.NewDecoder(response.Body).Decode(&nodeGET); err != nil { + logger.Log(0, "failed to decode node when running metrics update", err.Error()) + return + } + + metrics, err := metrics.Collect(nodeCfg.Node.Interface, nodeGET.PeerIDs) + if err != nil { + logger.Log(0, "failed metric collection for node", nodeCfg.Node.Name, err.Error()) + } + metrics.Network = nodeCfg.Node.Network + metrics.NodeName = nodeCfg.Node.Name + metrics.NodeID = nodeCfg.Node.ID + metrics.IsServer = "no" + data, err := json.Marshal(metrics) + if err != nil { + logger.Log(0, "something went wrong when marshalling metrics data for node", nodeCfg.Node.Name, err.Error()) + } + + if err = publish(nodeCfg, fmt.Sprintf("metrics/%s", nodeCfg.Node.ID), data, 1); err != nil { + logger.Log(0, "error occurred during publishing of metrics on node", nodeCfg.Node.Name, err.Error()) + logger.Log(0, "aggregating metrics locally until broker connection re-established") + val, ok := metricsCache.Load(nodeCfg.Node.ID) + if !ok { + metricsCache.Store(nodeCfg.Node.ID, data) + } else { + var oldMetrics models.Metrics + err = json.Unmarshal(val.([]byte), &oldMetrics) + if err == nil { + for k := range oldMetrics.Connectivity { + currentMetric := metrics.Connectivity[k] + if currentMetric.Latency == 0 { + currentMetric.Latency = oldMetrics.Connectivity[k].Latency + } + currentMetric.Uptime += oldMetrics.Connectivity[k].Uptime + currentMetric.TotalTime += oldMetrics.Connectivity[k].TotalTime + metrics.Connectivity[k] = currentMetric + } + newData, err := json.Marshal(metrics) + if err == nil { + metricsCache.Store(nodeCfg.Node.ID, newData) + } + } + } + } else { + metricsCache.Delete(nodeCfg.Node.ID) + logger.Log(0, "published metrics for node", nodeCfg.Node.Name) + } +} + // node cfg is required in order to fetch the traffic keys of that node for encryption func publish(nodeCfg *config.ClientConfig, dest string, msg []byte, qos byte) error { // setup the keys diff --git a/netclient/global_settings/globalsettings.go b/netclient/global_settings/globalsettings.go index 192c884f..086c952b 100644 --- a/netclient/global_settings/globalsettings.go +++ b/netclient/global_settings/globalsettings.go @@ -4,3 +4,6 @@ package global_settings // PublicIPServices - the list of user-specified IP services to use to obtain the node's public IP var PublicIPServices map[string]string = make(map[string]string) + +// User - holds a user string for joins when using basic auth +var User string diff --git a/scripts/nm-quick.sh b/scripts/nm-quick.sh index d1055a8b..50bfeca7 100755 --- a/scripts/nm-quick.sh +++ b/scripts/nm-quick.sh @@ -187,6 +187,8 @@ EOF echo "visit https://dashboard.$NETMAKER_BASE_DOMAIN to log in" +echo "visit https://grafana.$NETMAKER_BASE_DOMAIN to view metrics on grafana dashboard" +echo "visit https://prometheus.$NETMAKER_BASE_DOMAIN to view metrics on prometheus" sleep 7 setup_mesh() {( set -e diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 1b7c1815..262327a5 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -281,6 +281,21 @@ func IsRestBackend() bool { return isrest } +// IsMetricsExporter - checks if metrics exporter is on or off +func IsMetricsExporter() bool { + export := false + if os.Getenv("METRICS_EXPORTER") != "" { + if os.Getenv("METRICS_EXPORTER") == "on" { + export = true + } + } else if config.Config.Server.MetricsExporter != "" { + if config.Config.Server.MetricsExporter == "on" { + export = true + } + } + return export +} + // IsAgentBackend - checks if agent backed is on or off func IsAgentBackend() bool { isagent := true @@ -600,3 +615,32 @@ func GetMQServerPort() string { } return port } + +// IsBasicAuthEnabled - checks if basic auth has been configured to be turned off +func IsBasicAuthEnabled() bool { + var enabled = true //default + if os.Getenv("BASIC_AUTH") != "" { + enabled = os.Getenv("BASIC_AUTH") == "yes" + } else if config.Config.Server.BasicAuth != "" { + enabled = config.Config.Server.BasicAuth == "yes" + } + return enabled +} + +// GetLicenseKey - retrieves pro license value from env or conf files +func GetLicenseKey() string { + licenseKeyValue := os.Getenv("LICENSE_KEY") + if licenseKeyValue == "" { + licenseKeyValue = config.Config.Server.LicenseValue + } + return licenseKeyValue +} + +// GetNetmakerAccountID - get's the associated, Netmaker, account ID to verify ownership +func GetNetmakerAccountID() string { + netmakerAccountID := os.Getenv("NETMAKER_ACCOUNT_ID") + if netmakerAccountID == "" { + netmakerAccountID = config.Config.Server.LicenseValue + } + return netmakerAccountID +} diff --git a/serverctl/serverctl.go b/serverctl/serverctl.go index 8ffc76b9..4721d8c3 100644 --- a/serverctl/serverctl.go +++ b/serverctl/serverctl.go @@ -45,6 +45,9 @@ func InitServerNetclient() error { logger.Log(1, "failed pull for network", network.NetID, ", on server node", currentServerNode.ID) } } + if err = logic.InitializeNetUsers(&network); err != nil { + logger.Log(0, "something went wrong syncing usrs on network", network.NetID, "-", err.Error()) + } } } diff --git a/validation/validation.go b/validation/validation.go index 39e256b1..79746efd 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -11,6 +11,11 @@ func CheckYesOrNo(fl validator.FieldLevel) bool { return fl.Field().String() == "yes" || fl.Field().String() == "no" } +// CheckYesOrNoOrUnset - checks if a field is yes, no or unset +func CheckYesOrNoOrUnset(fl validator.FieldLevel) bool { + return CheckYesOrNo(fl) || fl.Field().String() == "unset" +} + // CheckRegex - check if a struct's field passes regex test func CheckRegex(fl validator.FieldLevel) bool { re := regexp.MustCompile(fl.Param())