diff --git a/README.md b/README.md index 04c592ce..7309d2e9 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@
- + diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..eb2b0f91 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,157 @@ +package auth + +import ( + "encoding/base64" + "encoding/json" + "net/http" + + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/crypto/bcrypt" + "golang.org/x/oauth2" +) + +// == consts == +const ( + init_provider = "initprovider" + get_user_info = "getuserinfo" + handle_callback = "handlecallback" + handle_login = "handlelogin" + google_provider_name = "google" + azure_ad_provider_name = "azure-ad" + github_provider_name = "github" + verify_user = "verifyuser" + auth_key = "netmaker_auth" +) + +var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login +var auth_provider *oauth2.Config + +func getCurrentAuthFunctions() map[string]interface{} { + var authInfo = servercfg.GetAuthProviderInfo() + var authProvider = authInfo[0] + switch authProvider { + case google_provider_name: + return google_functions + case azure_ad_provider_name: + return azure_ad_functions + case github_provider_name: + return github_functions + default: + return nil + } +} + +// InitializeAuthProvider - initializes the auth provider if any is present +func InitializeAuthProvider() string { + var functions = getCurrentAuthFunctions() + if functions == nil { + return "" + } + var _, err = fetchPassValue(logic.RandomString(64)) + if err != nil { + logic.Log(err.Error(), 0) + return "" + } + var currentFrontendURL = servercfg.GetFrontendURL() + if currentFrontendURL == "" { + return "" + } + var authInfo = servercfg.GetAuthProviderInfo() + functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2]) + return authInfo[0] +} + +// HandleAuthCallback - handles oauth callback +func HandleAuthCallback(w http.ResponseWriter, r *http.Request) { + var functions = getCurrentAuthFunctions() + if functions == nil { + return + } + functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r) +} + +// HandleAuthLogin - handles oauth login +func HandleAuthLogin(w http.ResponseWriter, r *http.Request) { + var functions = getCurrentAuthFunctions() + if functions == nil { + return + } + functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r) +} + +// IsOauthUser - returns +func IsOauthUser(user *models.User) error { + var currentValue, err = fetchPassValue("") + if err != nil { + return err + } + var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue)) + return bCryptErr +} + +// == private methods == + +func addUser(email string) error { + var hasAdmin, err = logic.HasAdmin() + if err != nil { + logic.Log("error checking for existence of admin user during OAuth login for "+email+", user not added", 1) + return err + } // generate random password to adapt to current model + var newPass, fetchErr = fetchPassValue("") + if fetchErr != nil { + return fetchErr + } + var newUser = models.User{ + UserName: email, + Password: newPass, + } + if !hasAdmin { // must be first attempt, create an admin + if newUser, err = logic.CreateAdmin(newUser); err != nil { + logic.Log("error creating admin from user, "+email+", user not added", 1) + } else { + logic.Log("admin created from user, "+email+", was first user added", 0) + } + } else { // otherwise add to db as admin..? + // TODO: add ability to add users with preemptive permissions + newUser.IsAdmin = false + if newUser, err = logic.CreateUser(newUser); err != nil { + logic.Log("error creating user, "+email+", user not added", 1) + } else { + logic.Log("user created from, "+email+"", 0) + } + } + return nil +} + +func fetchPassValue(newValue string) (string, error) { + + type valueHolder struct { + Value string `json:"value" bson:"value"` + } + var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue)) + var newValueHolder = &valueHolder{ + Value: b64NewValue, + } + var data, marshalErr = json.Marshal(newValueHolder) + if marshalErr != nil { + return "", marshalErr + } + + var currentValue, err = logic.FetchAuthSecret(auth_key, string(data)) + if err != nil { + return "", err + } + var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder) + if unmarshErr != nil { + return "", unmarshErr + } + + var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value) + if b64Err != nil { + logic.Log("could not decode pass", 0) + return "", nil + } + return string(b64CurrentValue), nil +} diff --git a/auth/azure-ad.go b/auth/azure-ad.go new file mode 100644 index 00000000..bb67b719 --- /dev/null +++ b/auth/azure-ad.go @@ -0,0 +1,126 @@ +package auth + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/oauth2" + "golang.org/x/oauth2/microsoft" +) + +var azure_ad_functions = map[string]interface{}{ + init_provider: initAzureAD, + get_user_info: getAzureUserInfo, + handle_callback: handleAzureCallback, + handle_login: handleAzureLogin, + verify_user: verifyAzureUser, +} + +type azureOauthUser struct { + UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"` + AccessToken string `json:"accesstoken" bson:"accesstoken"` +} + +// == handle azure ad authentication here == + +func initAzureAD(redirectURL string, clientID string, clientSecret string) { + auth_provider = &oauth2.Config{ + RedirectURL: redirectURL, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: []string{"User.Read"}, + Endpoint: microsoft.AzureADEndpoint(os.Getenv("AZURE_TENANT")), + } +} + +func handleAzureLogin(w http.ResponseWriter, r *http.Request) { + oauth_state_string = logic.RandomString(16) + if auth_provider == nil && servercfg.GetFrontendURL() != "" { + http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) + return + } else if auth_provider == nil { + fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func handleAzureCallback(w http.ResponseWriter, r *http.Request) { + + var content, err = getAzureUserInfo(r.FormValue("state"), r.FormValue("code")) + if err != nil { + logic.Log("error when getting user info from azure: "+err.Error(), 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + _, err = logic.GetUser(content.UserPrincipalName) + if err != nil { // user must not exists, so try to make one + if err = addUser(content.UserPrincipalName); err != nil { + return + } + } + var newPass, fetchErr = fetchPassValue("") + if fetchErr != nil { + return + } + // send a netmaker jwt token + var authRequest = models.UserAuthParams{ + UserName: content.UserPrincipalName, + Password: newPass, + } + + var jwt, jwtErr = logic.VerifyAuthRequest(authRequest) + if jwtErr != nil { + logic.Log("could not parse jwt for user "+authRequest.UserName, 1) + return + } + + logic.Log("completed azure OAuth sigin in for "+content.UserPrincipalName, 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect) +} + +func getAzureUserInfo(state string, code string) (*azureOauthUser, error) { + if state != oauth_state_string { + return nil, fmt.Errorf("invalid oauth state") + } + var token, err = auth_provider.Exchange(oauth2.NoContext, code) + if err != nil { + return nil, fmt.Errorf("code exchange failed: %s", err.Error()) + } + var data []byte + data, err = json.Marshal(token) + if err != nil { + return nil, fmt.Errorf("failed to convert token to json: %s", err.Error()) + } + var httpReq, reqErr = http.NewRequest("GET", "https://graph.microsoft.com/v1.0/me", nil) + if reqErr != nil { + return nil, fmt.Errorf("failed to create request to GitHub") + } + httpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) + response, err := http.DefaultClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("failed reading response body: %s", err.Error()) + } + var userInfo = &azureOauthUser{} + if err = json.Unmarshal(contents, userInfo); err != nil { + return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) + } + userInfo.AccessToken = string(data) + return userInfo, nil +} + +func verifyAzureUser(token *oauth2.Token) bool { + return token.Valid() +} diff --git a/auth/github.go b/auth/github.go new file mode 100644 index 00000000..552f2525 --- /dev/null +++ b/auth/github.go @@ -0,0 +1,129 @@ +package auth + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +var github_functions = map[string]interface{}{ + init_provider: initGithub, + get_user_info: getGithubUserInfo, + handle_callback: handleGithubCallback, + handle_login: handleGithubLogin, + verify_user: verifyGithubUser, +} + +type githubOauthUser struct { + Login string `json:"login" bson:"login"` + AccessToken string `json:"accesstoken" bson:"accesstoken"` +} + +// == handle github authentication here == + +func initGithub(redirectURL string, clientID string, clientSecret string) { + auth_provider = &oauth2.Config{ + RedirectURL: redirectURL, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: []string{}, + Endpoint: github.Endpoint, + } +} + +func handleGithubLogin(w http.ResponseWriter, r *http.Request) { + oauth_state_string = logic.RandomString(16) + if auth_provider == nil && servercfg.GetFrontendURL() != "" { + http.Redirect(w, r, servercfg.GetFrontendURL()+"?error=callback-error", http.StatusTemporaryRedirect) + return + } else if auth_provider == nil { + fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func handleGithubCallback(w http.ResponseWriter, r *http.Request) { + + var content, err = getGithubUserInfo(r.URL.Query().Get("state"), r.URL.Query().Get("code")) + if err != nil { + logic.Log("error when getting user info from github: "+err.Error(), 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + _, err = logic.GetUser(content.Login) + if err != nil { // user must not exist, so try to make one + if err = addUser(content.Login); err != nil { + return + } + } + var newPass, fetchErr = fetchPassValue("") + if fetchErr != nil { + return + } + // send a netmaker jwt token + var authRequest = models.UserAuthParams{ + UserName: content.Login, + Password: newPass, + } + + var jwt, jwtErr = logic.VerifyAuthRequest(authRequest) + if jwtErr != nil { + logic.Log("could not parse jwt for user "+authRequest.UserName, 1) + return + } + + logic.Log("completed github OAuth sigin in for "+content.Login, 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect) +} + +func getGithubUserInfo(state string, code string) (*githubOauthUser, error) { + if state != oauth_state_string { + return nil, fmt.Errorf("invalid OAuth state") + } + var token, err = auth_provider.Exchange(oauth2.NoContext, code) + if err != nil { + return nil, fmt.Errorf("code exchange failed: %s", err.Error()) + } + if !token.Valid() { + return nil, fmt.Errorf("GitHub code exchange yielded invalid token") + } + var data []byte + data, err = json.Marshal(token) + if err != nil { + return nil, fmt.Errorf("failed to convert token to json: %s", err.Error()) + } + var httpClient = &http.Client{} + var httpReq, reqErr = http.NewRequest("GET", "https://api.github.com/user", nil) + if reqErr != nil { + return nil, fmt.Errorf("failed to create request to GitHub") + } + httpReq.Header.Set("Authorization", "token "+token.AccessToken) + response, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("failed reading response body: %s", err.Error()) + } + var userInfo = &githubOauthUser{} + if err = json.Unmarshal(contents, userInfo); err != nil { + return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) + } + userInfo.AccessToken = string(data) + return userInfo, nil +} + +func verifyGithubUser(token *oauth2.Token) bool { + return token.Valid() +} diff --git a/auth/google.go b/auth/google.go new file mode 100644 index 00000000..91bb8030 --- /dev/null +++ b/auth/google.go @@ -0,0 +1,120 @@ +package auth + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +var google_functions = map[string]interface{}{ + init_provider: initGoogle, + get_user_info: getGoogleUserInfo, + handle_callback: handleGoogleCallback, + handle_login: handleGoogleLogin, + verify_user: verifyGoogleUser, +} + +type googleOauthUser struct { + Email string `json:"email" bson:"email"` + AccessToken string `json:"accesstoken" bson:"accesstoken"` +} + +// == handle google authentication here == + +func initGoogle(redirectURL string, clientID string, clientSecret string) { + auth_provider = &oauth2.Config{ + RedirectURL: redirectURL, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, + Endpoint: google.Endpoint, + } +} + +func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { + oauth_state_string = logic.RandomString(16) + if auth_provider == nil && servercfg.GetFrontendURL() != "" { + http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) + return + } else if auth_provider == nil { + fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { + + var content, err = getGoogleUserInfo(r.FormValue("state"), r.FormValue("code")) + if err != nil { + logic.Log("error when getting user info from google: "+err.Error(), 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + _, err = logic.GetUser(content.Email) + if err != nil { // user must not exists, so try to make one + if err = addUser(content.Email); err != nil { + return + } + } + var newPass, fetchErr = fetchPassValue("") + if fetchErr != nil { + return + } + // send a netmaker jwt token + var authRequest = models.UserAuthParams{ + UserName: content.Email, + Password: newPass, + } + + var jwt, jwtErr = logic.VerifyAuthRequest(authRequest) + if jwtErr != nil { + logic.Log("could not parse jwt for user "+authRequest.UserName, 1) + return + } + + logic.Log("completed google OAuth sigin in for "+content.Email, 1) + http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) +} + +func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) { + if state != oauth_state_string { + return nil, fmt.Errorf("invalid OAuth state") + } + var token, err = auth_provider.Exchange(oauth2.NoContext, code) + if err != nil { + return nil, fmt.Errorf("code exchange failed: %s", err.Error()) + } + var data []byte + data, err = json.Marshal(token) + if err != nil { + return nil, fmt.Errorf("failed to convert token to json: %s", err.Error()) + } + response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("failed reading response body: %s", err.Error()) + } + var userInfo = &googleOauthUser{} + if err = json.Unmarshal(contents, userInfo); err != nil { + return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) + } + userInfo.AccessToken = string(data) + return userInfo, nil +} + +func verifyGoogleUser(token *oauth2.Token) bool { + return token.Valid() +} diff --git a/config/config.go b/config/config.go index 367bbb83..c6663840 100644 --- a/config/config.go +++ b/config/config.go @@ -30,49 +30,52 @@ var Config *EnvironmentConfig // EnvironmentConfig : type EnvironmentConfig struct { Server ServerConfig `yaml:"server"` - SQL SQLConfig `yaml:"sql"` + SQL SQLConfig `yaml:"sql"` } // ServerConfig : type ServerConfig struct { - CoreDNSAddr string `yaml:"corednsaddr"` - APIConnString string `yaml:"apiconn"` - APIHost string `yaml:"apihost"` - APIPort string `yaml:"apiport"` - GRPCConnString string `yaml:"grpcconn"` - GRPCHost string `yaml:"grpchost"` - GRPCPort string `yaml:"grpcport"` - GRPCSecure string `yaml:"grpcsecure"` - MasterKey string `yaml:"masterkey"` - AllowedOrigin string `yaml:"allowedorigin"` - NodeID string `yaml:"nodeid"` - RestBackend string `yaml:"restbackend"` - AgentBackend string `yaml:"agentbackend"` - ClientMode string `yaml:"clientmode"` - DNSMode string `yaml:"dnsmode"` - SplitDNS string `yaml:"splitdns"` - DisableRemoteIPCheck string `yaml:"disableremoteipcheck"` - DisableDefaultNet string `yaml:"disabledefaultnet"` - GRPCSSL string `yaml:"grpcssl"` - Version string `yaml:"version"` - SQLConn string `yaml:"sqlconn"` - Platform string `yaml:"platform"` - Database string `yaml:database` - CheckinInterval string `yaml:checkininterval` - DefaultNodeLimit int32 `yaml:"defaultnodelimit"` - Verbosity int32 `yaml:"verbosity"` + CoreDNSAddr string `yaml:"corednsaddr"` + APIConnString string `yaml:"apiconn"` + APIHost string `yaml:"apihost"` + APIPort string `yaml:"apiport"` + GRPCConnString string `yaml:"grpcconn"` + GRPCHost string `yaml:"grpchost"` + GRPCPort string `yaml:"grpcport"` + GRPCSecure string `yaml:"grpcsecure"` + MasterKey string `yaml:"masterkey"` + AllowedOrigin string `yaml:"allowedorigin"` + NodeID string `yaml:"nodeid"` + RestBackend string `yaml:"restbackend"` + AgentBackend string `yaml:"agentbackend"` + ClientMode string `yaml:"clientmode"` + DNSMode string `yaml:"dnsmode"` + SplitDNS string `yaml:"splitdns"` + DisableRemoteIPCheck string `yaml:"disableremoteipcheck"` + DisableDefaultNet string `yaml:"disabledefaultnet"` + GRPCSSL string `yaml:"grpcssl"` + Version string `yaml:"version"` + SQLConn string `yaml:"sqlconn"` + Platform string `yaml:"platform"` + Database string `yaml:database` + CheckinInterval string `yaml:checkininterval` + DefaultNodeLimit int32 `yaml:"defaultnodelimit"` + Verbosity int32 `yaml:"verbosity"` ServerCheckinInterval int64 `yaml:"servercheckininterval"` + AuthProvider string `yaml:"authprovider"` + ClientID string `yaml:"clientid"` + ClientSecret string `yaml:"clientsecret"` + FrontendURL string `yaml:"frontendurl"` } - // Generic SQL Config type SQLConfig struct { - Host string `yaml:"host"` - Port int32 `yaml:"port"` + Host string `yaml:"host"` + Port int32 `yaml:"port"` Username string `yaml:"username"` Password string `yaml:"password"` - DB string `yaml:"db"` - SSLMode string `yaml:"sslmode"` + DB string `yaml:"db"` + SSLMode string `yaml:"sslmode"` } //reading in the env file diff --git a/controllers/controller.go b/controllers/controller.go index 07bce556..662e3440 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/servercfg" ) @@ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) { log.Println(err) } }() - - log.Println("REST Server successfully started on port " + port + " (REST)") + logic.Log("REST Server successfully started on port "+port+" (REST)", 0) c := make(chan os.Signal) // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) @@ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) { <-c // After receiving CTRL+C Properly stop the server - log.Println("Stopping the REST server...") + logic.Log("Stopping the REST server...", 0) srv.Shutdown(context.TODO()) - log.Println("REST Server closed.") + logic.Log("REST Server closed.", 0) } diff --git a/controllers/extClientHttpController.go b/controllers/extClientHttpController.go index 5eca78ce..5cdda0a1 100644 --- a/controllers/extClientHttpController.go +++ b/controllers/extClientHttpController.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "math/rand" "net/http" "strconv" "time" @@ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { "Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1) returnSuccessResponse(w, r, params["clientid"]+" deleted.") } - -// StringWithCharset - returns a random string in a charset -func StringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) -} - -const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -var seededRand *rand.Rand = rand.New( - rand.NewSource(time.Now().UnixNano())) diff --git a/controllers/responseHttp.go b/controllers/responseHttp.go index e44ab2e1..594d1e4d 100644 --- a/controllers/responseHttp.go +++ b/controllers/responseHttp.go @@ -2,9 +2,9 @@ package controller import ( "encoding/json" - "fmt" "net/http" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" ) @@ -48,7 +48,7 @@ func returnErrorResponse(response http.ResponseWriter, request *http.Request, er if err != nil { panic(err) } - fmt.Println(errorMessage) + logic.Log("processed request error: "+errorMessage.Message, 1) response.Header().Set("Content-Type", "application/json") response.WriteHeader(errorMessage.Code) response.Write(jsonResponse) diff --git a/controllers/userHttpController.go b/controllers/userHttpController.go index 999fa89b..c150ddeb 100644 --- a/controllers/userHttpController.go +++ b/controllers/userHttpController.go @@ -7,12 +7,12 @@ import ( "net/http" "strings" - "github.com/go-playground/validator/v10" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/auth" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/functions" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" - "golang.org/x/crypto/bcrypt" ) func userHandlers(r *mux.Router) { @@ -21,11 +21,19 @@ func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST") r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST") r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(updateUser))).Methods("PUT") + r.HandleFunc("/api/users/networks/{username}", authorizeUserAdm(http.HandlerFunc(updateUserNetworks))).Methods("PUT") r.HandleFunc("/api/users/{username}/adm", authorizeUserAdm(http.HandlerFunc(updateUserAdm))).Methods("PUT") r.HandleFunc("/api/users/{username}", authorizeUserAdm(http.HandlerFunc(createUser))).Methods("POST") r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE") r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET") r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET") + r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") + r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") + r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET") +} + +func throwOauthError(response http.ResponseWriter, request *http.Request) { + returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized")) } // Node authenticates using its password and retrieves a JWT for authorization. @@ -46,7 +54,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { return } - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) if err != nil { returnErrorResponse(response, request, formatError(err, "badrequest")) return @@ -79,35 +87,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) { response.Write(successJSONResponse) } -// VerifyAuthRequest - verifies an auth request -func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) { - var result models.User - if authRequest.UserName == "" { - return "", errors.New("username can't be empty") - } else if authRequest.Password == "" { - return "", errors.New("password can't be empty") - } - //Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved). - record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName) - if err != nil { - return "", errors.New("incorrect credentials") - } - if err = json.Unmarshal([]byte(record), &result); err != nil { - return "", errors.New("incorrect credentials") - } - - // compare password from request to stored password in database - // might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text... - // TODO: Consider a way of hashing the password client side before sending, or using certificates - if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil { - return "", errors.New("incorrect credentials") - } - - //Create a new JWT for the node - tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin) - return tokenString, nil -} - // The middleware for most requests to the API // They all pass through here first // This will validate the JWT (or check for master token) @@ -181,37 +160,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error { return nil } -// HasAdmin - checks if server has an admin -func HasAdmin() (bool, error) { - - collection, err := database.FetchRecords(database.USERS_TABLE_NAME) - if err != nil { - if database.IsEmptyRecord(err) { - return false, nil - } else { - return true, err - - } - } - for _, value := range collection { // filter for isadmin true - var user models.User - err = json.Unmarshal([]byte(value), &user) - if err != nil { - continue - } - if user.IsAdmin { - return true, nil - } - } - - return false, err -} - func hasAdmin(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - hasadmin, err := HasAdmin() + hasadmin, err := logic.HasAdmin() if err != nil { returnErrorResponse(w, r, formatError(err, "internal")) return @@ -221,20 +174,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) { } -// GetUser - gets a user -func GetUser(username string) (models.ReturnUser, error) { - - var user models.ReturnUser - record, err := database.FetchRecord(database.USERS_TABLE_NAME, username) - if err != nil { - return user, err - } - if err = json.Unmarshal([]byte(record), &user); err != nil { - return models.ReturnUser{}, err - } - return user, err -} - // GetUserInternal - gets an internal user func GetUserInternal(username string) (models.User, error) { @@ -249,30 +188,6 @@ func GetUserInternal(username string) (models.User, error) { return user, err } -// GetUsers - gets users -func GetUsers() ([]models.ReturnUser, error) { - - var users []models.ReturnUser - - collection, err := database.FetchRecords(database.USERS_TABLE_NAME) - - if err != nil { - return users, err - } - - for _, value := range collection { - - var user models.ReturnUser - err = json.Unmarshal([]byte(value), &user) - if err != nil { - continue // get users - } - users = append(users, user) - } - - return users, err -} - // Get an individual node. Nothin fancy here folks. func getUser(w http.ResponseWriter, r *http.Request) { // set header. @@ -280,7 +195,7 @@ func getUser(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) usernameFetched := params["username"] - user, err := GetUser(usernameFetched) + user, err := logic.GetUser(usernameFetched) if err != nil { returnErrorResponse(w, r, formatError(err, "internal")) @@ -295,7 +210,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) { // set header. w.Header().Set("Content-Type", "application/json") - users, err := GetUsers() + users, err := logic.GetUsers() if err != nil { returnErrorResponse(w, r, formatError(err, "internal")) @@ -306,42 +221,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(users) } -// CreateUser - creates a user -func CreateUser(user models.User) (models.User, error) { - // check if user exists - if _, err := GetUser(user.UserName); err == nil { - return models.User{}, errors.New("user exists") - } - err := ValidateUser("create", user) - if err != nil { - return models.User{}, err - } - - // encrypt that password so we never see it again - hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5) - if err != nil { - return user, err - } - // set password to encrypted password - user.Password = string(hash) - - tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) - - if tokenString == "" { - // returnErrorResponse(w, r, errorResponse) - return user, err - } - - // connect db - data, err := json.Marshal(&user) - if err != nil { - return user, err - } - err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) - - return user, err -} - func createAdmin(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -349,7 +228,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { // get node from body of request _ = json.NewDecoder(r.Body).Decode(&admin) - admin, err := CreateAdmin(admin) + admin, err := logic.CreateAdmin(admin) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) @@ -359,18 +238,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(admin) } -func CreateAdmin(admin models.User) (models.User, error) { - hasadmin, err := HasAdmin() - if err != nil { - return models.User{}, err - } - if hasadmin { - return models.User{}, errors.New("admin user already exists") - } - admin.IsAdmin = true - return CreateUser(admin) -} - func createUser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -378,7 +245,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { // get node from body of request _ = json.NewDecoder(r.Body).Decode(&user) - user, err := CreateUser(user) + user, err := logic.CreateUser(user) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) @@ -388,53 +255,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(user) } -// UpdateUser - updates a given user -func UpdateUser(userchange models.User, user models.User) (models.User, error) { - //check if user exists - if _, err := GetUser(user.UserName); err != nil { - return models.User{}, err - } - - err := ValidateUser("update", userchange) - if err != nil { - return models.User{}, err - } - - queryUser := user.UserName - - if userchange.UserName != "" { - user.UserName = userchange.UserName - } - if len(userchange.Networks) > 0 { - user.Networks = userchange.Networks - } - if userchange.Password != "" { - // encrypt that password so we never see it again - hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) - - if err != nil { - return userchange, err - } - // set password to encrypted password - userchange.Password = string(hash) - - user.Password = userchange.Password - } - if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { - return models.User{}, err - } - data, err := json.Marshal(&user) - if err != nil { - return models.User{}, err - } - if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil { - return models.User{}, err - } - functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1) - return user, nil -} - -func updateUser(w http.ResponseWriter, r *http.Request) { +func updateUserNetworks(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var params = mux.Vars(r) var user models.User @@ -452,8 +273,40 @@ func updateUser(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + + err = logic.UpdateUserNetworks(userchange.Networks, userchange.IsAdmin, &user) + if err != nil { + returnErrorResponse(w, r, formatError(err, "badrequest")) + return + } + functions.PrintUserLog(username, "status was updated", 1) + json.NewEncoder(w).Encode(user) +} + +func updateUser(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + var params = mux.Vars(r) + var user models.User + // start here + username := params["username"] + user, err := GetUserInternal(username) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } + if auth.IsOauthUser(&user) == nil { + returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user %s", username), "forbidden")) + return + } + var userchange models.User + // we decode our body request params + err = json.NewDecoder(r.Body).Decode(&userchange) + if err != nil { + returnErrorResponse(w, r, formatError(err, "internal")) + return + } userchange.Networks = nil - user, err = UpdateUser(userchange, user) + user, err = logic.UpdateUser(userchange, user) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) return @@ -473,6 +326,10 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } + if auth.IsOauthUser(&user) != nil { + returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user"), "forbidden")) + return + } var userchange models.User // we decode our body request params err = json.NewDecoder(r.Body).Decode(&userchange) @@ -480,7 +337,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { returnErrorResponse(w, r, formatError(err, "internal")) return } - user, err = UpdateUser(userchange, user) + user, err = logic.UpdateUser(userchange, user) if err != nil { returnErrorResponse(w, r, formatError(err, "badrequest")) return @@ -489,20 +346,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(user) } -// DeleteUser - deletes a given user -func DeleteUser(user string) (bool, error) { - - if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 { - return false, errors.New("user does not exist") - } - - err := database.DeleteRecord(database.USERS_TABLE_NAME, user) - if err != nil { - return false, err - } - return true, nil -} - func deleteUser(w http.ResponseWriter, r *http.Request) { // Set header w.Header().Set("Content-Type", "application/json") @@ -511,7 +354,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) username := params["username"] - success, err := DeleteUser(username) + success, err := logic.DeleteUser(username) if err != nil { returnErrorResponse(w, r, formatError(err, "internal")) @@ -524,17 +367,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { functions.PrintUserLog(username, "was deleted", 1) json.NewEncoder(w).Encode(params["username"] + " deleted.") } - -// ValidateUser - validates a user model -func ValidateUser(operation string, user models.User) error { - - v := validator.New() - err := v.Struct(user) - - if err != nil { - for _, e := range err.(validator.ValidationErrors) { - fmt.Println(e) - } - } - return err -} diff --git a/controllers/userHttpController_test.go b/controllers/userHttpController_test.go index 7efe642d..3afea694 100644 --- a/controllers/userHttpController_test.go +++ b/controllers/userHttpController_test.go @@ -4,52 +4,53 @@ import ( "testing" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" ) func deleteAllUsers() { - users, _ := GetUsers() + users, _ := logic.GetUsers() for _, user := range users { - DeleteUser(user.UserName) + logic.DeleteUser(user.UserName) } } func TestHasAdmin(t *testing.T) { //delete all current users database.InitializeDatabase() - users, _ := GetUsers() + users, _ := logic.GetUsers() for _, user := range users { - success, err := DeleteUser(user.UserName) + success, err := logic.DeleteUser(user.UserName) assert.Nil(t, err) assert.True(t, success) } t.Run("NoUser", func(t *testing.T) { - found, err := HasAdmin() + found, err := logic.HasAdmin() assert.Nil(t, err) assert.False(t, found) }) t.Run("No admin user", func(t *testing.T) { var user = models.User{"noadmin", "password", nil, false} - _, err := CreateUser(user) + _, err := logic.CreateUser(user) assert.Nil(t, err) - found, err := HasAdmin() + found, err := logic.HasAdmin() assert.Nil(t, err) assert.False(t, found) }) t.Run("admin user", func(t *testing.T) { var user = models.User{"admin", "password", nil, true} - _, err := CreateUser(user) + _, err := logic.CreateUser(user) assert.Nil(t, err) - found, err := HasAdmin() + found, err := logic.HasAdmin() assert.Nil(t, err) assert.True(t, found) }) t.Run("multiple admins", func(t *testing.T) { var user = models.User{"admin1", "password", nil, true} - _, err := CreateUser(user) + _, err := logic.CreateUser(user) assert.Nil(t, err) - found, err := HasAdmin() + found, err := logic.HasAdmin() assert.Nil(t, err) assert.True(t, found) }) @@ -60,12 +61,12 @@ func TestCreateUser(t *testing.T) { deleteAllUsers() user := models.User{"admin", "password", nil, true} t.Run("NoUser", func(t *testing.T) { - admin, err := CreateUser(user) + admin, err := logic.CreateUser(user) assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) }) t.Run("UserExists", func(t *testing.T) { - _, err := CreateUser(user) + _, err := logic.CreateUser(user) assert.NotNil(t, err) assert.EqualError(t, err, "user exists") }) @@ -78,14 +79,14 @@ func TestCreateAdmin(t *testing.T) { t.Run("NoAdmin", func(t *testing.T) { user.UserName = "admin" user.Password = "password" - admin, err := CreateAdmin(user) + admin, err := logic.CreateAdmin(user) assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) }) t.Run("AdminExists", func(t *testing.T) { user.UserName = "admin2" user.Password = "password1" - admin, err := CreateAdmin(user) + admin, err := logic.CreateAdmin(user) assert.EqualError(t, err, "admin user already exists") assert.Equal(t, admin, models.User{}) }) @@ -95,14 +96,14 @@ func TestDeleteUser(t *testing.T) { database.InitializeDatabase() deleteAllUsers() t.Run("NonExistent User", func(t *testing.T) { - deleted, err := DeleteUser("admin") + deleted, err := logic.DeleteUser("admin") assert.EqualError(t, err, "user does not exist") assert.False(t, deleted) }) t.Run("Existing User", func(t *testing.T) { user := models.User{"admin", "password", nil, true} - CreateUser(user) - deleted, err := DeleteUser("admin") + logic.CreateUser(user) + deleted, err := logic.DeleteUser("admin") assert.Nil(t, err) assert.True(t, deleted) }) @@ -114,44 +115,44 @@ func TestValidateUser(t *testing.T) { t.Run("Valid Create", func(t *testing.T) { user.UserName = "admin" user.Password = "validpass" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.Nil(t, err) }) t.Run("Valid Update", func(t *testing.T) { user.UserName = "admin" user.Password = "password" - err := ValidateUser("update", user) + err := logic.ValidateUser(user) assert.Nil(t, err) }) t.Run("Invalid UserName", func(t *testing.T) { t.Skip() user.UserName = "*invalid" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.Error(t, err) //assert.Contains(t, err.Error(), "Field validation for 'UserName' failed") }) t.Run("Short UserName", func(t *testing.T) { t.Skip() user.UserName = "1" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.NotNil(t, err) //assert.Contains(t, err.Error(), "Field validation for 'UserName' failed") }) t.Run("Empty UserName", func(t *testing.T) { t.Skip() user.UserName = "" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.EqualError(t, err, "some string") //assert.Contains(t, err.Error(), "Field validation for 'UserName' failed") }) t.Run("EmptyPassword", func(t *testing.T) { user.Password = "" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'required' tag") }) t.Run("ShortPassword", func(t *testing.T) { user.Password = "123" - err := ValidateUser("create", user) + err := logic.ValidateUser(user) assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'min' tag") }) } @@ -160,14 +161,14 @@ func TestGetUser(t *testing.T) { database.InitializeDatabase() deleteAllUsers() t.Run("NonExistantUser", func(t *testing.T) { - admin, err := GetUser("admin") + admin, err := logic.GetUser("admin") assert.EqualError(t, err, "could not find any records") assert.Equal(t, "", admin.UserName) }) t.Run("UserExisits", func(t *testing.T) { user := models.User{"admin", "password", nil, true} - CreateUser(user) - admin, err := GetUser("admin") + logic.CreateUser(user) + admin, err := logic.GetUser("admin") assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) }) @@ -183,7 +184,7 @@ func TestGetUserInternal(t *testing.T) { }) t.Run("UserExisits", func(t *testing.T) { user := models.User{"admin", "password", nil, true} - CreateUser(user) + logic.CreateUser(user) admin, err := GetUserInternal("admin") assert.Nil(t, err) assert.Equal(t, user.UserName, admin.UserName) @@ -194,21 +195,21 @@ func TestGetUsers(t *testing.T) { database.InitializeDatabase() deleteAllUsers() t.Run("NonExistantUser", func(t *testing.T) { - admin, err := GetUsers() + admin, err := logic.GetUsers() assert.EqualError(t, err, "could not find any records") assert.Equal(t, []models.ReturnUser(nil), admin) }) t.Run("UserExisits", func(t *testing.T) { user := models.User{"admin", "password", nil, true} - CreateUser(user) - admins, err := GetUsers() + logic.CreateUser(user) + admins, err := logic.GetUsers() assert.Nil(t, err) assert.Equal(t, user.UserName, admins[0].UserName) }) t.Run("MulipleUsers", func(t *testing.T) { user := models.User{"user", "password", nil, true} - CreateUser(user) - admins, err := GetUsers() + logic.CreateUser(user) + admins, err := logic.GetUsers() assert.Nil(t, err) for _, u := range admins { if u.UserName == "admin" { @@ -227,14 +228,14 @@ func TestUpdateUser(t *testing.T) { user := models.User{"admin", "password", nil, true} newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true} t.Run("NonExistantUser", func(t *testing.T) { - admin, err := UpdateUser(newuser, user) + admin, err := logic.UpdateUser(newuser, user) assert.EqualError(t, err, "could not find any records") assert.Equal(t, "", admin.UserName) }) t.Run("UserExists", func(t *testing.T) { - CreateUser(user) - admin, err := UpdateUser(newuser, user) + logic.CreateUser(user) + admin, err := logic.UpdateUser(newuser, user) assert.Nil(t, err) assert.Equal(t, newuser.UserName, admin.UserName) }) @@ -271,43 +272,43 @@ func TestVerifyAuthRequest(t *testing.T) { t.Run("EmptyUserName", func(t *testing.T) { authRequest.UserName = "" authRequest.Password = "Password" - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.Equal(t, "", jwt) assert.EqualError(t, err, "username can't be empty") }) t.Run("EmptyPassword", func(t *testing.T) { authRequest.UserName = "admin" authRequest.Password = "" - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.Equal(t, "", jwt) assert.EqualError(t, err, "password can't be empty") }) t.Run("NonExistantUser", func(t *testing.T) { authRequest.UserName = "admin" authRequest.Password = "password" - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.Equal(t, "", jwt) assert.EqualError(t, err, "incorrect credentials") }) t.Run("Non-Admin", func(t *testing.T) { user := models.User{"nonadmin", "somepass", nil, false} - CreateUser(user) + logic.CreateUser(user) authRequest := models.UserAuthParams{"nonadmin", "somepass"} - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.NotNil(t, jwt) assert.Nil(t, err) }) t.Run("WrongPassword", func(t *testing.T) { user := models.User{"admin", "password", nil, false} - CreateUser(user) + logic.CreateUser(user) authRequest := models.UserAuthParams{"admin", "badpass"} - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.Equal(t, "", jwt) assert.EqualError(t, err, "incorrect credentials") }) t.Run("Success", func(t *testing.T) { authRequest := models.UserAuthParams{"admin", "password"} - jwt, err := VerifyAuthRequest(authRequest) + jwt, err := logic.VerifyAuthRequest(authRequest) assert.Nil(t, err) assert.NotNil(t, jwt) }) diff --git a/database/database.go b/database/database.go index 2c4d50d6..7981ece2 100644 --- a/database/database.go +++ b/database/database.go @@ -39,6 +39,9 @@ const SERVERCONF_TABLE_NAME = "serverconf" // DATABASE_FILENAME - database file name const DATABASE_FILENAME = "netmaker.db" +// GENERATED_TABLE_NAME - stores server generated k/v +const GENERATED_TABLE_NAME = "generated" + // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -87,11 +90,11 @@ func getCurrentDB() map[string]interface{} { } func InitializeDatabase() error { - log.Println("connecting to", servercfg.GetDB()) + log.Println("[netmaker] connecting to", servercfg.GetDB()) tperiod := time.Now().Add(10 * time.Second) for { if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil { - log.Println("unable to connect to db, retrying . . .") + log.Println("[netmaker] unable to connect to db, retrying . . .") if time.Now().After(tperiod) { return err } @@ -114,6 +117,7 @@ func createTables() { createTable(INT_CLIENTS_TABLE_NAME) createTable(PEERS_TABLE_NAME) createTable(SERVERCONF_TABLE_NAME) + createTable(GENERATED_TABLE_NAME) } func createTable(tableName string) error { diff --git a/go.mod b/go.mod index 00453777..59748f11 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/urfave/cli/v2 v2.3.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 // indirect + golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e // indirect golang.org/x/text v0.3.7-0.20210524175448-3115f89c4b99 // indirect golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 // indirect diff --git a/logic/auth.go b/logic/auth.go new file mode 100644 index 00000000..d29a9ecf --- /dev/null +++ b/logic/auth.go @@ -0,0 +1,268 @@ +package logic + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/go-playground/validator/v10" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/functions" + "github.com/gravitl/netmaker/models" + "golang.org/x/crypto/bcrypt" +) + +// HasAdmin - checks if server has an admin +func HasAdmin() (bool, error) { + + collection, err := database.FetchRecords(database.USERS_TABLE_NAME) + if err != nil { + if database.IsEmptyRecord(err) { + return false, nil + } else { + return true, err + } + } + for _, value := range collection { // filter for isadmin true + var user models.User + err = json.Unmarshal([]byte(value), &user) + if err != nil { + continue + } + if user.IsAdmin { + return true, nil + } + } + + return false, err +} + +// GetUser - gets a user +func GetUser(username string) (models.ReturnUser, error) { + + var user models.ReturnUser + record, err := database.FetchRecord(database.USERS_TABLE_NAME, username) + if err != nil { + return user, err + } + if err = json.Unmarshal([]byte(record), &user); err != nil { + return models.ReturnUser{}, err + } + return user, err +} + +// GetUsers - gets users +func GetUsers() ([]models.ReturnUser, error) { + + var users []models.ReturnUser + + collection, err := database.FetchRecords(database.USERS_TABLE_NAME) + + if err != nil { + return users, err + } + + for _, value := range collection { + + var user models.ReturnUser + err = json.Unmarshal([]byte(value), &user) + if err != nil { + continue // get users + } + users = append(users, user) + } + + return users, err +} + +// CreateUser - creates a user +func CreateUser(user models.User) (models.User, error) { + // check if user exists + if _, err := GetUser(user.UserName); err == nil { + return models.User{}, errors.New("user exists") + } + var err = ValidateUser(user) + if err != nil { + return models.User{}, err + } + + // encrypt that password so we never see it again + hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5) + if err != nil { + return user, err + } + // set password to encrypted password + user.Password = string(hash) + + tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) + + if tokenString == "" { + // returnErrorResponse(w, r, errorResponse) + return user, err + } + + // connect db + data, err := json.Marshal(&user) + if err != nil { + return user, err + } + err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) + + return user, err +} + +// CreateAdmin - creates an admin user +func CreateAdmin(admin models.User) (models.User, error) { + hasadmin, err := HasAdmin() + if err != nil { + return models.User{}, err + } + if hasadmin { + return models.User{}, errors.New("admin user already exists") + } + admin.IsAdmin = true + return CreateUser(admin) +} + +// VerifyAuthRequest - verifies an auth request +func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) { + var result models.User + if authRequest.UserName == "" { + return "", errors.New("username can't be empty") + } else if authRequest.Password == "" { + return "", errors.New("password can't be empty") + } + //Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved). + record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName) + if err != nil { + return "", errors.New("incorrect credentials") + } + if err = json.Unmarshal([]byte(record), &result); err != nil { + return "", errors.New("incorrect credentials") + } + + // compare password from request to stored password in database + // might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text... + // TODO: Consider a way of hashing the password client side before sending, or using certificates + if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil { + return "", errors.New("incorrect credentials") + } + + //Create a new JWT for the node + tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin) + return tokenString, nil +} + +// UpdateUserNetworks - updates the networks of a given user +func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.User) error { + // check if user exists + if returnedUser, err := GetUser(currentUser.UserName); err != nil { + return err + } else if returnedUser.IsAdmin { + return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName) + } + if isadmin { + currentUser.IsAdmin = true + currentUser.Networks = nil + } else { + currentUser.Networks = newNetworks + } + + data, err := json.Marshal(currentUser) + if err != nil { + return err + } + if err = database.Insert(currentUser.UserName, string(data), database.USERS_TABLE_NAME); err != nil { + return err + } + + return nil +} + +// UpdateUser - updates a given user +func UpdateUser(userchange models.User, user models.User) (models.User, error) { + //check if user exists + if _, err := GetUser(user.UserName); err != nil { + return models.User{}, err + } + + err := ValidateUser(userchange) + if err != nil { + return models.User{}, err + } + + queryUser := user.UserName + + if userchange.UserName != "" { + user.UserName = userchange.UserName + } + if len(userchange.Networks) > 0 { + user.Networks = userchange.Networks + } + if userchange.Password != "" { + // encrypt that password so we never see it again + hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) + + if err != nil { + return userchange, err + } + // set password to encrypted password + userchange.Password = string(hash) + + user.Password = userchange.Password + } + if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { + return models.User{}, err + } + data, err := json.Marshal(&user) + if err != nil { + return models.User{}, err + } + if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil { + return models.User{}, err + } + Log("updated user "+queryUser, 1) + return user, nil +} + +// ValidateUser - validates a user model +func ValidateUser(user models.User) error { + + v := validator.New() + err := v.Struct(user) + + if err != nil { + for _, e := range err.(validator.ValidationErrors) { + Log(e.Error(), 2) + } + } + + return err +} + +// DeleteUser - deletes a given user +func DeleteUser(user string) (bool, error) { + + if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 { + return false, errors.New("user does not exist") + } + + err := database.DeleteRecord(database.USERS_TABLE_NAME, user) + if err != nil { + return false, err + } + return true, nil +} + +// FetchAuthSecret - manages secrets for oauth +func FetchAuthSecret(key string, secret string) (string, error) { + var record, err = database.FetchRecord(database.GENERATED_TABLE_NAME, key) + if err != nil { + if err = database.Insert(key, secret, database.GENERATED_TABLE_NAME); err != nil { + return "", err + } else { + return secret, nil + } + } + return record, nil +} diff --git a/logic/network.go b/logic/network.go index c2bb6bb1..6d763f74 100644 --- a/logic/network.go +++ b/logic/network.go @@ -5,10 +5,17 @@ import ( "os/exec" "strings" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" ) +// CheckNetworkExists - checks i a network exists for this netmaker instance +func CheckNetworkExists(network string) bool { + var _, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, network) + return err == nil +} + // GetLocalIP - gets the local ip func GetLocalIP(node models.Node) string { diff --git a/logic/util.go b/logic/util.go index 3ea52a93..c733ccf6 100644 --- a/logic/util.go +++ b/logic/util.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "log" + "math/rand" "strconv" "strings" "time" @@ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin return peers, err } +// RandomString - returns a random string in a charset +func RandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + func setPeerInfo(node models.Node) models.Node { var peer models.Node peer.RelayAddrs = node.RelayAddrs @@ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node { func Log(message string, loglevel int) { log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) - if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 { + if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 { log.Println("[netmaker] " + message) } } diff --git a/logic/wireguard.go b/logic/wireguard.go index 45191fbc..cb318f06 100644 --- a/logic/wireguard.go +++ b/logic/wireguard.go @@ -62,10 +62,11 @@ func setWGConfig(node models.Node, network string, peerupdate bool) error { var iface string iface = node.Interface err = setServerPeers(iface, node.PersistentKeepalive, peers) + Log("updated peers on server "+node.Name, 2) } else { err = initWireguard(&node, privkey, peers, hasGateway, gateways) + Log("finished setting wg config on server "+node.Name, 3) } - Log("finished setting wg config on server "+node.Name, 1) return err } diff --git a/main.go b/main.go index d98cc6a0..999adacf 100644 --- a/main.go +++ b/main.go @@ -13,11 +13,13 @@ import ( "sync" "time" + "github.com/gravitl/netmaker/auth" controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/dnslogic" "github.com/gravitl/netmaker/functions" nodepb "github.com/gravitl/netmaker/grpc" + "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" @@ -35,20 +37,29 @@ func main() { func initialize() { // Client Mode Prereq Check var err error + if err = database.InitializeDatabase(); err != nil { - log.Println("Error connecting to database.") + logic.Log("Error connecting to database", 0) log.Fatal(err) } - log.Println("database successfully connected.") + logic.Log("database successfully connected", 0) + + var authProvider = auth.InitializeAuthProvider() + if authProvider != "" { + logic.Log("OAuth provider, "+authProvider+", initialized", 0) + } else { + logic.Log("no OAuth provider found or not configured, continuing without OAuth", 0) + } + if servercfg.IsClientMode() != "off" { output, err := ncutils.RunCmd("id -u", true) if err != nil { - log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") + logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0) log.Fatal(output, err) } uid, err := strconv.Atoi(string(output[:len(output)-1])) if err != nil { - log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.") + logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0) log.Fatal(err) } if uid != 0 { @@ -74,7 +85,7 @@ func startControllers() { if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { err := servercfg.SetHost() if err != nil { - log.Println("Unable to Set host. Exiting...") + logic.Log("Unable to Set host. Exiting...", 0) log.Fatal(err) } } @@ -90,7 +101,7 @@ func startControllers() { if servercfg.IsDNSMode() { err := dnslogic.SetDNS() if err != nil { - log.Println("error occurred initializing DNS:", err) + logic.Log("error occurred initializing DNS: "+err.Error(), 0) } } //Run Rest Server @@ -98,7 +109,7 @@ func startControllers() { if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { err := servercfg.SetHost() if err != nil { - log.Println("Unable to Set host. Exiting...") + logic.Log("Unable to Set host. Exiting...", 0) log.Fatal(err) } } @@ -106,11 +117,11 @@ func startControllers() { controller.HandleRESTRequests(&waitnetwork) } if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() { - log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.") + logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0) } waitnetwork.Wait() - log.Println("[netmaker] exiting") + logic.Log("exiting", 0) } func runClient(wg *sync.WaitGroup) { @@ -139,7 +150,7 @@ func runGRPC(wg *sync.WaitGroup) { listener, err := net.Listen("tcp", ":"+grpcport) // Handle errors if any if err != nil { - log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err) + log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err) } s := grpc.NewServer( @@ -157,7 +168,7 @@ func runGRPC(wg *sync.WaitGroup) { log.Fatalf("Failed to serve: %v", err) } }() - log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)") + logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0) // Right way to stop the server using a SHUTDOWN HOOK // Create a channel to receive OS signals @@ -172,11 +183,11 @@ func runGRPC(wg *sync.WaitGroup) { <-c // After receiving CTRL+C Properly stop the server - log.Println("Stopping the Agent server...") + logic.Log("Stopping the Agent server...", 0) s.Stop() listener.Close() - log.Println("Agent server closed..") - log.Println("Closed DB connection.") + logic.Log("Agent server closed..", 0) + logic.Log("Closed DB connection.", 0) } func authServerUnaryInterceptor() grpc.ServerOption { diff --git a/netclient/functions/join.go b/netclient/functions/join.go index c24de8c1..fe12433c 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "log" - "os/exec" nodepb "github.com/gravitl/netmaker/grpc" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/auth" @@ -18,6 +16,8 @@ import ( "github.com/gravitl/netmaker/netclient/wireguard" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "log" + "os/exec" ) // JoinNetwork - helps a client join a network @@ -84,8 +84,8 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error { if ncutils.IsLinux() { _, err := exec.LookPath("resolvectl") if err != nil { - ncutils.PrintLog("resolvectl not present",2) - ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management",2) + ncutils.PrintLog("resolvectl not present", 2) + ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management", 2) cfg.Node.DNSOn = "no" } } diff --git a/netclient/ncutils/netclientutils.go b/netclient/ncutils/netclientutils.go index 830253e2..793d93f7 100644 --- a/netclient/ncutils/netclientutils.go +++ b/netclient/ncutils/netclientutils.go @@ -155,9 +155,9 @@ func parsePeers(keepalive int32, peers []wgtypes.PeerConfig) (string, error) { if keepalive <= 0 { keepalive = 20 } - + for _, peer := range peers { - endpointString := "" + endpointString := "" if peer.Endpoint != nil && peer.Endpoint.String() != "" { endpointString += "Endpoint = " + peer.Endpoint.String() } diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 29676f23..8e9d8b4b 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strconv" + "strings" "github.com/gravitl/netmaker/config" ) @@ -65,8 +66,26 @@ func GetServerConfig() config.ServerConfig { cfg.Database = GetDB() cfg.Platform = GetPlatform() cfg.Version = GetVersion() + + // == auth config == + var authInfo = GetAuthProviderInfo() + cfg.AuthProvider = authInfo[0] + cfg.ClientID = authInfo[1] + cfg.ClientSecret = authInfo[2] + cfg.FrontendURL = GetFrontendURL() + return cfg } +func GetFrontendURL() string { + var frontend = "" + if os.Getenv("FRONTEND_URL") != "" { + frontend = os.Getenv("FRONTEND_URL") + } else if config.Config.Server.FrontendURL != "" { + frontend = config.Config.Server.FrontendURL + } + return frontend +} + func GetAPIConnString() string { conn := "" if os.Getenv("SERVER_API_CONN_STRING") != "" { @@ -77,7 +96,7 @@ func GetAPIConnString() string { return conn } func GetVersion() string { - version := "0.8.4" + version := "0.8.5" if config.Config.Server.Version != "" { version = config.Config.Server.Version } @@ -398,6 +417,25 @@ func GetServerCheckinInterval() int64 { return t } +// GetAuthProviderInfo = gets the oauth provider info +func GetAuthProviderInfo() []string { + var authProvider = "" + if os.Getenv("AUTH_PROVIDER") != "" && os.Getenv("CLIENT_ID") != "" && os.Getenv("CLIENT_SECRET") != "" { + authProvider = strings.ToLower(os.Getenv("AUTH_PROVIDER")) + if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" { + return []string{authProvider, os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET")} + } else { + authProvider = "" + } + } else if config.Config.Server.AuthProvider != "" && config.Config.Server.ClientID != "" && config.Config.Server.ClientSecret != "" { + authProvider = strings.ToLower(config.Config.Server.AuthProvider) + if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" { + return []string{authProvider, config.Config.Server.ClientID, config.Config.Server.ClientSecret} + } + } + return []string{"", "", ""} +} + // GetMacAddr - get's mac address func getMacAddr() string { ifas, err := net.Interfaces() diff --git a/servercfg/sqlconf.go b/servercfg/sqlconf.go index 0cef15ca..91dc4736 100644 --- a/servercfg/sqlconf.go +++ b/servercfg/sqlconf.go @@ -1,8 +1,8 @@ package servercfg import ( - "os" "github.com/gravitl/netmaker/config" + "os" "strconv" )