mirror of
				https://github.com/gravitl/netmaker.git
				synced 2025-10-31 00:17:13 +08:00 
			
		
		
		
	began oauth implementation
This commit is contained in:
		
							parent
							
								
									8a54f50676
								
							
						
					
					
						commit
						4e4e8b3ab5
					
				
					 11 changed files with 330 additions and 243 deletions
				
			
		
							
								
								
									
										31
									
								
								auth/auth.go
									
										
									
									
									
								
							
							
						
						
									
										31
									
								
								auth/auth.go
									
										
									
									
									
								
							|  | @ -1,6 +1,7 @@ | ||||||
| package auth | package auth | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gravitl/netmaker/servercfg" | 	"github.com/gravitl/netmaker/servercfg" | ||||||
|  | @ -13,14 +14,20 @@ const ( | ||||||
| 	get_user_info          = "getuserinfo" | 	get_user_info          = "getuserinfo" | ||||||
| 	handle_callback        = "handlecallback" | 	handle_callback        = "handlecallback" | ||||||
| 	handle_login           = "handlelogin" | 	handle_login           = "handlelogin" | ||||||
| 	oauth_state_string     = "netmaker-oauth-state" |  | ||||||
| 	google_provider_name   = "google" | 	google_provider_name   = "google" | ||||||
| 	azure_ad_provider_name = "azure-ad" | 	azure_ad_provider_name = "azure-ad" | ||||||
| 	github_provider_name   = "github" | 	github_provider_name   = "github" | ||||||
|  | 	verify_user            = "verifyuser" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login | ||||||
| var auth_provider *oauth2.Config | var auth_provider *oauth2.Config | ||||||
| 
 | 
 | ||||||
|  | type OauthUser struct { | ||||||
|  | 	Email       string `json:"email" bson:"email"` | ||||||
|  | 	AccessToken string `json:"accesstoken" bson:"accesstoken"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func getCurrentAuthFunctions() map[string]interface{} { | func getCurrentAuthFunctions() map[string]interface{} { | ||||||
| 	var authInfo = servercfg.GetAuthProviderInfo() | 	var authInfo = servercfg.GetAuthProviderInfo() | ||||||
| 	var authProvider = authInfo[0] | 	var authProvider = authInfo[0] | ||||||
|  | @ -37,14 +44,14 @@ func getCurrentAuthFunctions() map[string]interface{} { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // InitializeAuthProvider - initializes the auth provider if any is present | // InitializeAuthProvider - initializes the auth provider if any is present | ||||||
| func InitializeAuthProvider() bool { | func InitializeAuthProvider() string { | ||||||
| 	var functions = getCurrentAuthFunctions() | 	var functions = getCurrentAuthFunctions() | ||||||
| 	if functions == nil { | 	if functions == nil { | ||||||
| 		return false | 		return "" | ||||||
| 	} | 	} | ||||||
| 	var authInfo = servercfg.GetAuthProviderInfo() | 	var authInfo = servercfg.GetAuthProviderInfo() | ||||||
| 	functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString(), authInfo[1], authInfo[2]) | 	functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2]) | ||||||
| 	return auth_provider != nil | 	return authInfo[0] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // HandleAuthCallback - handles oauth callback | // HandleAuthCallback - handles oauth callback | ||||||
|  | @ -64,3 +71,17 @@ func HandleAuthLogin(w http.ResponseWriter, r *http.Request) { | ||||||
| 	} | 	} | ||||||
| 	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r) | 	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // VerifyUserToken - checks if oauth2 token is valid | ||||||
|  | func VerifyUserToken(accessToken string) bool { | ||||||
|  | 	var token = &oauth2.Token{} | ||||||
|  | 	var err = json.Unmarshal([]byte(accessToken), token) | ||||||
|  | 	if err != nil || !token.Valid() { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	var functions = getCurrentAuthFunctions() | ||||||
|  | 	if functions == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return functions[verify_user].(func(*oauth2.Token) bool)(token) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,10 +1,13 @@ | ||||||
| package auth | package auth | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/gravitl/netmaker/logic" | ||||||
|  | 	"github.com/gravitl/netmaker/servercfg" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| 	"golang.org/x/oauth2/google" | 	"golang.org/x/oauth2/google" | ||||||
| ) | ) | ||||||
|  | @ -14,6 +17,7 @@ var google_functions = map[string]interface{}{ | ||||||
| 	get_user_info:   getUserInfo, | 	get_user_info:   getUserInfo, | ||||||
| 	handle_callback: handleGoogleCallback, | 	handle_callback: handleGoogleCallback, | ||||||
| 	handle_login:    handleGoogleLogin, | 	handle_login:    handleGoogleLogin, | ||||||
|  | 	verify_user:     verifyGoogleUser, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // == handle google authentication here == | // == handle google authentication here == | ||||||
|  | @ -29,6 +33,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { | func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 	oauth_state_string = logic.RandomString(16) | ||||||
| 	url := auth_provider.AuthCodeURL(oauth_state_string) | 	url := auth_provider.AuthCodeURL(oauth_state_string) | ||||||
| 	http.Redirect(w, r, url, http.StatusTemporaryRedirect) | 	http.Redirect(w, r, url, http.StatusTemporaryRedirect) | ||||||
| } | } | ||||||
|  | @ -38,20 +43,26 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { | ||||||
| 	var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code")) | 	var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		fmt.Println(err.Error()) | 		fmt.Println(err.Error()) | ||||||
| 		http.Redirect(w, r, "/api/oauth/error", http.StatusTemporaryRedirect) | 		http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	fmt.Fprintf(w, "Content: %s\n", content) | 	logic.Log("completed google oauth sigin in for "+content.Email, 0) | ||||||
|  | 	http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth="+content.AccessToken+"&email="+content.Email, http.StatusPermanentRedirect) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getUserInfo(state string, code string) ([]byte, error) { | func getUserInfo(state string, code string) (*OauthUser, error) { | ||||||
| 	if state != oauth_state_string { | 	if state != oauth_state_string { | ||||||
| 		return nil, fmt.Errorf("invalid oauth state") | 		return nil, fmt.Errorf("invalid oauth state") | ||||||
| 	} | 	} | ||||||
| 	token, err := auth_provider.Exchange(oauth2.NoContext, code) | 	var token, err = auth_provider.Exchange(oauth2.NoContext, code) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("code exchange failed: %s", err.Error()) | 		return nil, fmt.Errorf("code exchange failed: %s", err.Error()) | ||||||
| 	} | 	} | ||||||
|  | 	var data []byte | ||||||
|  | 	data, err = json.Marshal(token) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to convert token to json: %s", err.Error()) | ||||||
|  | 	} | ||||||
| 	response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) | 	response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed getting user info: %s", err.Error()) | 		return nil, fmt.Errorf("failed getting user info: %s", err.Error()) | ||||||
|  | @ -61,5 +72,19 @@ func getUserInfo(state string, code string) ([]byte, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed reading response body: %s", err.Error()) | 		return nil, fmt.Errorf("failed reading response body: %s", err.Error()) | ||||||
| 	} | 	} | ||||||
| 	return contents, nil | 	var userInfo = &OauthUser{} | ||||||
|  | 	if err = json.Unmarshal(contents, userInfo); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error()) | ||||||
|  | 	} | ||||||
|  | 	userInfo.AccessToken = string(data) | ||||||
|  | 	return userInfo, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func verifyGoogleUser(token *oauth2.Token) bool { | ||||||
|  | 	if token.Valid() { | ||||||
|  | 		var err error | ||||||
|  | 		_, err = http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) | ||||||
|  | 		return err == nil | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -65,6 +65,7 @@ type ServerConfig struct { | ||||||
| 	AuthProvider          string `yaml:"authprovider"` | 	AuthProvider          string `yaml:"authprovider"` | ||||||
| 	ClientID              string `yaml:"clientid"` | 	ClientID              string `yaml:"clientid"` | ||||||
| 	ClientSecret          string `yaml:"clientsecret"` | 	ClientSecret          string `yaml:"clientsecret"` | ||||||
|  | 	FrontendURL           string `yaml:"frontendurl"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Generic SQL Config | // Generic SQL Config | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/gorilla/handlers" | 	"github.com/gorilla/handlers" | ||||||
| 	"github.com/gorilla/mux" | 	"github.com/gorilla/mux" | ||||||
|  | 	"github.com/gravitl/netmaker/logic" | ||||||
| 	"github.com/gravitl/netmaker/servercfg" | 	"github.com/gravitl/netmaker/servercfg" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) { | ||||||
| 			log.Println(err) | 			log.Println(err) | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 	logic.Log("REST Server successfully started on port "+port+" (REST)", 0) | ||||||
| 	log.Println("REST Server successfully started on port " + port + " (REST)") |  | ||||||
| 	c := make(chan os.Signal) | 	c := make(chan os.Signal) | ||||||
| 
 | 
 | ||||||
| 	// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) | 	// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) | ||||||
|  | @ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) { | ||||||
| 	<-c | 	<-c | ||||||
| 
 | 
 | ||||||
| 	// After receiving CTRL+C Properly stop the server | 	// After receiving CTRL+C Properly stop the server | ||||||
| 	log.Println("Stopping the REST server...") | 	logic.Log("Stopping the REST server...", 0) | ||||||
| 	srv.Shutdown(context.TODO()) | 	srv.Shutdown(context.TODO()) | ||||||
| 	log.Println("REST Server closed.") | 	logic.Log("REST Server closed.", 0) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -5,7 +5,6 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"math/rand" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { | ||||||
| 		"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1) | 		"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1) | ||||||
| 	returnSuccessResponse(w, r, params["clientid"]+" deleted.") | 	returnSuccessResponse(w, r, params["clientid"]+" deleted.") | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // StringWithCharset - returns a random string in a charset |  | ||||||
| func StringWithCharset(length int, charset string) string { |  | ||||||
| 	b := make([]byte, length) |  | ||||||
| 	for i := range b { |  | ||||||
| 		b[i] = charset[seededRand.Intn(len(charset))] |  | ||||||
| 	} |  | ||||||
| 	return string(b) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" |  | ||||||
| 
 |  | ||||||
| var seededRand *rand.Rand = rand.New( |  | ||||||
| 	rand.NewSource(time.Now().UnixNano())) |  | ||||||
|  |  | ||||||
|  | @ -3,14 +3,14 @@ package controller | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-playground/validator/v10" |  | ||||||
| 	"github.com/gorilla/mux" | 	"github.com/gorilla/mux" | ||||||
|  | 	"github.com/gravitl/netmaker/auth" | ||||||
| 	"github.com/gravitl/netmaker/database" | 	"github.com/gravitl/netmaker/database" | ||||||
| 	"github.com/gravitl/netmaker/functions" | 	"github.com/gravitl/netmaker/functions" | ||||||
|  | 	"github.com/gravitl/netmaker/logic" | ||||||
| 	"github.com/gravitl/netmaker/models" | 	"github.com/gravitl/netmaker/models" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  | @ -26,6 +26,13 @@ func userHandlers(r *mux.Router) { | ||||||
| 	r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE") | 	r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE") | ||||||
| 	r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET") | 	r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET") | ||||||
| 	r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET") | 	r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET") | ||||||
|  | 	r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET") | ||||||
|  | 	r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET") | ||||||
|  | 	r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func throwOauthError(response http.ResponseWriter, request *http.Request) { | ||||||
|  | 	returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Node authenticates using its password and retrieves a JWT for authorization. | // Node authenticates using its password and retrieves a JWT for authorization. | ||||||
|  | @ -181,37 +188,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // HasAdmin - checks if server has an admin |  | ||||||
| func HasAdmin() (bool, error) { |  | ||||||
| 
 |  | ||||||
| 	collection, err := database.FetchRecords(database.USERS_TABLE_NAME) |  | ||||||
| 	if err != nil { |  | ||||||
| 		if database.IsEmptyRecord(err) { |  | ||||||
| 			return false, nil |  | ||||||
| 		} else { |  | ||||||
| 			return true, err |  | ||||||
| 
 |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	for _, value := range collection { // filter for isadmin true |  | ||||||
| 		var user models.User |  | ||||||
| 		err = json.Unmarshal([]byte(value), &user) |  | ||||||
| 		if err != nil { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if user.IsAdmin { |  | ||||||
| 			return true, nil |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return false, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func hasAdmin(w http.ResponseWriter, r *http.Request) { | func hasAdmin(w http.ResponseWriter, r *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 
 | 
 | ||||||
| 	hasadmin, err := HasAdmin() | 	hasadmin, err := logic.HasAdmin() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "internal")) | 		returnErrorResponse(w, r, formatError(err, "internal")) | ||||||
| 		return | 		return | ||||||
|  | @ -221,20 +202,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) { | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetUser - gets a user |  | ||||||
| func GetUser(username string) (models.ReturnUser, error) { |  | ||||||
| 
 |  | ||||||
| 	var user models.ReturnUser |  | ||||||
| 	record, err := database.FetchRecord(database.USERS_TABLE_NAME, username) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	if err = json.Unmarshal([]byte(record), &user); err != nil { |  | ||||||
| 		return models.ReturnUser{}, err |  | ||||||
| 	} |  | ||||||
| 	return user, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // GetUserInternal - gets an internal user | // GetUserInternal - gets an internal user | ||||||
| func GetUserInternal(username string) (models.User, error) { | func GetUserInternal(username string) (models.User, error) { | ||||||
| 
 | 
 | ||||||
|  | @ -249,30 +216,6 @@ func GetUserInternal(username string) (models.User, error) { | ||||||
| 	return user, err | 	return user, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetUsers - gets users |  | ||||||
| func GetUsers() ([]models.ReturnUser, error) { |  | ||||||
| 
 |  | ||||||
| 	var users []models.ReturnUser |  | ||||||
| 
 |  | ||||||
| 	collection, err := database.FetchRecords(database.USERS_TABLE_NAME) |  | ||||||
| 
 |  | ||||||
| 	if err != nil { |  | ||||||
| 		return users, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, value := range collection { |  | ||||||
| 
 |  | ||||||
| 		var user models.ReturnUser |  | ||||||
| 		err = json.Unmarshal([]byte(value), &user) |  | ||||||
| 		if err != nil { |  | ||||||
| 			continue // get users |  | ||||||
| 		} |  | ||||||
| 		users = append(users, user) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return users, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Get an individual node. Nothin fancy here folks. | // Get an individual node. Nothin fancy here folks. | ||||||
| func getUser(w http.ResponseWriter, r *http.Request) { | func getUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// set header. | 	// set header. | ||||||
|  | @ -280,7 +223,7 @@ func getUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	var params = mux.Vars(r) | 	var params = mux.Vars(r) | ||||||
| 	usernameFetched := params["username"] | 	usernameFetched := params["username"] | ||||||
| 	user, err := GetUser(usernameFetched) | 	user, err := logic.GetUser(usernameFetched) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "internal")) | 		returnErrorResponse(w, r, formatError(err, "internal")) | ||||||
|  | @ -295,7 +238,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// set header. | 	// set header. | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 
 | 
 | ||||||
| 	users, err := GetUsers() | 	users, err := logic.GetUsers() | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "internal")) | 		returnErrorResponse(w, r, formatError(err, "internal")) | ||||||
|  | @ -306,42 +249,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) { | ||||||
| 	json.NewEncoder(w).Encode(users) | 	json.NewEncoder(w).Encode(users) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CreateUser - creates a user |  | ||||||
| func CreateUser(user models.User) (models.User, error) { |  | ||||||
| 	// check if user exists |  | ||||||
| 	if _, err := GetUser(user.UserName); err == nil { |  | ||||||
| 		return models.User{}, errors.New("user exists") |  | ||||||
| 	} |  | ||||||
| 	err := ValidateUser("create", user) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// encrypt that password so we never see it again |  | ||||||
| 	hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	// set password to encrypted password |  | ||||||
| 	user.Password = string(hash) |  | ||||||
| 
 |  | ||||||
| 	tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) |  | ||||||
| 
 |  | ||||||
| 	if tokenString == "" { |  | ||||||
| 		// returnErrorResponse(w, r, errorResponse) |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// connect db |  | ||||||
| 	data, err := json.Marshal(&user) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) |  | ||||||
| 
 |  | ||||||
| 	return user, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func createAdmin(w http.ResponseWriter, r *http.Request) { | func createAdmin(w http.ResponseWriter, r *http.Request) { | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 
 | 
 | ||||||
|  | @ -349,7 +256,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// get node from body of request | 	// get node from body of request | ||||||
| 	_ = json.NewDecoder(r.Body).Decode(&admin) | 	_ = json.NewDecoder(r.Body).Decode(&admin) | ||||||
| 
 | 
 | ||||||
| 	admin, err := CreateAdmin(admin) | 	admin, err := logic.CreateAdmin(admin) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "badrequest")) | 		returnErrorResponse(w, r, formatError(err, "badrequest")) | ||||||
|  | @ -359,18 +266,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) { | ||||||
| 	json.NewEncoder(w).Encode(admin) | 	json.NewEncoder(w).Encode(admin) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CreateAdmin(admin models.User) (models.User, error) { |  | ||||||
| 	hasadmin, err := HasAdmin() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 	if hasadmin { |  | ||||||
| 		return models.User{}, errors.New("admin user already exists") |  | ||||||
| 	} |  | ||||||
| 	admin.IsAdmin = true |  | ||||||
| 	return CreateUser(admin) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func createUser(w http.ResponseWriter, r *http.Request) { | func createUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 
 | 
 | ||||||
|  | @ -378,7 +273,7 @@ func createUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// get node from body of request | 	// get node from body of request | ||||||
| 	_ = json.NewDecoder(r.Body).Decode(&user) | 	_ = json.NewDecoder(r.Body).Decode(&user) | ||||||
| 
 | 
 | ||||||
| 	user, err := CreateUser(user) | 	user, err := logic.CreateUser(user) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "badrequest")) | 		returnErrorResponse(w, r, formatError(err, "badrequest")) | ||||||
|  | @ -388,52 +283,6 @@ func createUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	json.NewEncoder(w).Encode(user) | 	json.NewEncoder(w).Encode(user) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UpdateUser - updates a given user |  | ||||||
| func UpdateUser(userchange models.User, user models.User) (models.User, error) { |  | ||||||
| 	//check if user exists |  | ||||||
| 	if _, err := GetUser(user.UserName); err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err := ValidateUser("update", userchange) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	queryUser := user.UserName |  | ||||||
| 
 |  | ||||||
| 	if userchange.UserName != "" { |  | ||||||
| 		user.UserName = userchange.UserName |  | ||||||
| 	} |  | ||||||
| 	if len(userchange.Networks) > 0 { |  | ||||||
| 		user.Networks = userchange.Networks |  | ||||||
| 	} |  | ||||||
| 	if userchange.Password != "" { |  | ||||||
| 		// encrypt that password so we never see it again |  | ||||||
| 		hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) |  | ||||||
| 
 |  | ||||||
| 		if err != nil { |  | ||||||
| 			return userchange, err |  | ||||||
| 		} |  | ||||||
| 		// set password to encrypted password |  | ||||||
| 		userchange.Password = string(hash) |  | ||||||
| 
 |  | ||||||
| 		user.Password = userchange.Password |  | ||||||
| 	} |  | ||||||
| 	if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 	data, err := json.Marshal(&user) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 	if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil { |  | ||||||
| 		return models.User{}, err |  | ||||||
| 	} |  | ||||||
| 	functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1) |  | ||||||
| 	return user, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func updateUser(w http.ResponseWriter, r *http.Request) { | func updateUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 	var params = mux.Vars(r) | 	var params = mux.Vars(r) | ||||||
|  | @ -453,7 +302,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	userchange.Networks = nil | 	userchange.Networks = nil | ||||||
| 	user, err = UpdateUser(userchange, user) | 	user, err = logic.UpdateUser(userchange, user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "badrequest")) | 		returnErrorResponse(w, r, formatError(err, "badrequest")) | ||||||
| 		return | 		return | ||||||
|  | @ -480,7 +329,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "internal")) | 		returnErrorResponse(w, r, formatError(err, "internal")) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	user, err = UpdateUser(userchange, user) | 	user, err = logic.UpdateUser(userchange, user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "badrequest")) | 		returnErrorResponse(w, r, formatError(err, "badrequest")) | ||||||
| 		return | 		return | ||||||
|  | @ -489,20 +338,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) { | ||||||
| 	json.NewEncoder(w).Encode(user) | 	json.NewEncoder(w).Encode(user) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // DeleteUser - deletes a given user |  | ||||||
| func DeleteUser(user string) (bool, error) { |  | ||||||
| 
 |  | ||||||
| 	if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 { |  | ||||||
| 		return false, errors.New("user does not exist") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err := database.DeleteRecord(database.USERS_TABLE_NAME, user) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 	return true, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func deleteUser(w http.ResponseWriter, r *http.Request) { | func deleteUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	// Set header | 	// Set header | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
|  | @ -511,7 +346,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	var params = mux.Vars(r) | 	var params = mux.Vars(r) | ||||||
| 
 | 
 | ||||||
| 	username := params["username"] | 	username := params["username"] | ||||||
| 	success, err := DeleteUser(username) | 	success, err := logic.DeleteUser(username) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		returnErrorResponse(w, r, formatError(err, "internal")) | 		returnErrorResponse(w, r, formatError(err, "internal")) | ||||||
|  | @ -524,17 +359,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { | ||||||
| 	functions.PrintUserLog(username, "was deleted", 1) | 	functions.PrintUserLog(username, "was deleted", 1) | ||||||
| 	json.NewEncoder(w).Encode(params["username"] + " deleted.") | 	json.NewEncoder(w).Encode(params["username"] + " deleted.") | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // ValidateUser - validates a user model |  | ||||||
| func ValidateUser(operation string, user models.User) error { |  | ||||||
| 
 |  | ||||||
| 	v := validator.New() |  | ||||||
| 	err := v.Struct(user) |  | ||||||
| 
 |  | ||||||
| 	if err != nil { |  | ||||||
| 		for _, e := range err.(validator.ValidationErrors) { |  | ||||||
| 			fmt.Println(e) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -87,11 +87,11 @@ func getCurrentDB() map[string]interface{} { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func InitializeDatabase() error { | func InitializeDatabase() error { | ||||||
| 	log.Println("connecting to", servercfg.GetDB()) | 	log.Println("[netmaker] connecting to", servercfg.GetDB()) | ||||||
| 	tperiod := time.Now().Add(10 * time.Second) | 	tperiod := time.Now().Add(10 * time.Second) | ||||||
| 	for { | 	for { | ||||||
| 		if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil { | 		if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil { | ||||||
| 			log.Println("unable to connect to db, retrying . . .") | 			log.Println("[netmaker] unable to connect to db, retrying . . .") | ||||||
| 			if time.Now().After(tperiod) { | 			if time.Now().After(tperiod) { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
							
								
								
									
										199
									
								
								logic/auth.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								logic/auth.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,199 @@ | ||||||
|  | package logic | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	"github.com/go-playground/validator/v10" | ||||||
|  | 	"github.com/gravitl/netmaker/database" | ||||||
|  | 	"github.com/gravitl/netmaker/functions" | ||||||
|  | 	"github.com/gravitl/netmaker/models" | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // HasAdmin - checks if server has an admin | ||||||
|  | func HasAdmin() (bool, error) { | ||||||
|  | 
 | ||||||
|  | 	collection, err := database.FetchRecords(database.USERS_TABLE_NAME) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if database.IsEmptyRecord(err) { | ||||||
|  | 			return false, nil | ||||||
|  | 		} else { | ||||||
|  | 			return true, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	for _, value := range collection { // filter for isadmin true | ||||||
|  | 		var user models.User | ||||||
|  | 		err = json.Unmarshal([]byte(value), &user) | ||||||
|  | 		if err != nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if user.IsAdmin { | ||||||
|  | 			return true, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetUser - gets a user | ||||||
|  | func GetUser(username string) (models.ReturnUser, error) { | ||||||
|  | 
 | ||||||
|  | 	var user models.ReturnUser | ||||||
|  | 	record, err := database.FetchRecord(database.USERS_TABLE_NAME, username) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	if err = json.Unmarshal([]byte(record), &user); err != nil { | ||||||
|  | 		return models.ReturnUser{}, err | ||||||
|  | 	} | ||||||
|  | 	return user, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetUsers - gets users | ||||||
|  | func GetUsers() ([]models.ReturnUser, error) { | ||||||
|  | 
 | ||||||
|  | 	var users []models.ReturnUser | ||||||
|  | 
 | ||||||
|  | 	collection, err := database.FetchRecords(database.USERS_TABLE_NAME) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return users, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, value := range collection { | ||||||
|  | 
 | ||||||
|  | 		var user models.ReturnUser | ||||||
|  | 		err = json.Unmarshal([]byte(value), &user) | ||||||
|  | 		if err != nil { | ||||||
|  | 			continue // get users | ||||||
|  | 		} | ||||||
|  | 		users = append(users, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return users, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CreateUser - creates a user | ||||||
|  | func CreateUser(user models.User) (models.User, error) { | ||||||
|  | 	// check if user exists | ||||||
|  | 	if _, err := GetUser(user.UserName); err == nil { | ||||||
|  | 		return models.User{}, errors.New("user exists") | ||||||
|  | 	} | ||||||
|  | 	var err = ValidateUser(user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// encrypt that password so we never see it again | ||||||
|  | 	hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	// set password to encrypted password | ||||||
|  | 	user.Password = string(hash) | ||||||
|  | 
 | ||||||
|  | 	tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin) | ||||||
|  | 
 | ||||||
|  | 	if tokenString == "" { | ||||||
|  | 		// returnErrorResponse(w, r, errorResponse) | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// connect db | ||||||
|  | 	data, err := json.Marshal(&user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME) | ||||||
|  | 
 | ||||||
|  | 	return user, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CreateAdmin - creates an admin user | ||||||
|  | func CreateAdmin(admin models.User) (models.User, error) { | ||||||
|  | 	hasadmin, err := HasAdmin() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 	if hasadmin { | ||||||
|  | 		return models.User{}, errors.New("admin user already exists") | ||||||
|  | 	} | ||||||
|  | 	admin.IsAdmin = true | ||||||
|  | 	return CreateUser(admin) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // UpdateUser - updates a given user | ||||||
|  | func UpdateUser(userchange models.User, user models.User) (models.User, error) { | ||||||
|  | 	//check if user exists | ||||||
|  | 	if _, err := GetUser(user.UserName); err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err := ValidateUser(userchange) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	queryUser := user.UserName | ||||||
|  | 
 | ||||||
|  | 	if userchange.UserName != "" { | ||||||
|  | 		user.UserName = userchange.UserName | ||||||
|  | 	} | ||||||
|  | 	if len(userchange.Networks) > 0 { | ||||||
|  | 		user.Networks = userchange.Networks | ||||||
|  | 	} | ||||||
|  | 	if userchange.Password != "" { | ||||||
|  | 		// encrypt that password so we never see it again | ||||||
|  | 		hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5) | ||||||
|  | 
 | ||||||
|  | 		if err != nil { | ||||||
|  | 			return userchange, err | ||||||
|  | 		} | ||||||
|  | 		// set password to encrypted password | ||||||
|  | 		userchange.Password = string(hash) | ||||||
|  | 
 | ||||||
|  | 		user.Password = userchange.Password | ||||||
|  | 	} | ||||||
|  | 	if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 	data, err := json.Marshal(&user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 	if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil { | ||||||
|  | 		return models.User{}, err | ||||||
|  | 	} | ||||||
|  | 	functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1) | ||||||
|  | 	return user, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ValidateUser - validates a user model | ||||||
|  | func ValidateUser(user models.User) error { | ||||||
|  | 
 | ||||||
|  | 	v := validator.New() | ||||||
|  | 	err := v.Struct(user) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		for _, e := range err.(validator.ValidationErrors) { | ||||||
|  | 			Log(e.Error(), 2) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DeleteUser - deletes a given user | ||||||
|  | func DeleteUser(user string) (bool, error) { | ||||||
|  | 
 | ||||||
|  | 	if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 { | ||||||
|  | 		return false, errors.New("user does not exist") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err := database.DeleteRecord(database.USERS_TABLE_NAME, user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"log" | 	"log" | ||||||
|  | 	"math/rand" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin | ||||||
| 	return peers, err | 	return peers, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RandomString - returns a random string in a charset | ||||||
|  | func RandomString(length int) string { | ||||||
|  | 	const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" | ||||||
|  | 
 | ||||||
|  | 	var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) | ||||||
|  | 
 | ||||||
|  | 	b := make([]byte, length) | ||||||
|  | 	for i := range b { | ||||||
|  | 		b[i] = charset[seededRand.Intn(len(charset))] | ||||||
|  | 	} | ||||||
|  | 	return string(b) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func setPeerInfo(node models.Node) models.Node { | func setPeerInfo(node models.Node) models.Node { | ||||||
| 	var peer models.Node | 	var peer models.Node | ||||||
| 	peer.RelayAddrs = node.RelayAddrs | 	peer.RelayAddrs = node.RelayAddrs | ||||||
|  | @ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node { | ||||||
| 
 | 
 | ||||||
| func Log(message string, loglevel int) { | func Log(message string, loglevel int) { | ||||||
| 	log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) | 	log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) | ||||||
| 	if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 { | 	if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 { | ||||||
| 		log.Println("[netmaker] " + message) | 		log.Println("[netmaker] " + message) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										37
									
								
								main.go
									
										
									
									
									
								
							
							
						
						
									
										37
									
								
								main.go
									
										
									
									
									
								
							|  | @ -13,11 +13,13 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/gravitl/netmaker/auth" | ||||||
| 	controller "github.com/gravitl/netmaker/controllers" | 	controller "github.com/gravitl/netmaker/controllers" | ||||||
| 	"github.com/gravitl/netmaker/database" | 	"github.com/gravitl/netmaker/database" | ||||||
| 	"github.com/gravitl/netmaker/dnslogic" | 	"github.com/gravitl/netmaker/dnslogic" | ||||||
| 	"github.com/gravitl/netmaker/functions" | 	"github.com/gravitl/netmaker/functions" | ||||||
| 	nodepb "github.com/gravitl/netmaker/grpc" | 	nodepb "github.com/gravitl/netmaker/grpc" | ||||||
|  | 	"github.com/gravitl/netmaker/logic" | ||||||
| 	"github.com/gravitl/netmaker/models" | 	"github.com/gravitl/netmaker/models" | ||||||
| 	"github.com/gravitl/netmaker/netclient/ncutils" | 	"github.com/gravitl/netmaker/netclient/ncutils" | ||||||
| 	"github.com/gravitl/netmaker/servercfg" | 	"github.com/gravitl/netmaker/servercfg" | ||||||
|  | @ -35,20 +37,27 @@ func main() { | ||||||
| 
 | 
 | ||||||
| func initialize() { // Client Mode Prereq Check | func initialize() { // Client Mode Prereq Check | ||||||
| 	var err error | 	var err error | ||||||
|  | 
 | ||||||
| 	if err = database.InitializeDatabase(); err != nil { | 	if err = database.InitializeDatabase(); err != nil { | ||||||
| 		log.Println("Error connecting to database.") | 		logic.Log("Error connecting to database", 0) | ||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	log.Println("database successfully connected.") | 	logic.Log("database successfully connected", 0) | ||||||
|  | 
 | ||||||
|  | 	var authProvider = auth.InitializeAuthProvider() | ||||||
|  | 	if authProvider != "" { | ||||||
|  | 		logic.Log("OAuth provider, "+authProvider+", initialized", 0) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if servercfg.IsClientMode() != "off" { | 	if servercfg.IsClientMode() != "off" { | ||||||
| 		output, err := ncutils.RunCmd("id -u", true) | 		output, err := ncutils.RunCmd("id -u", true) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") | 			logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0) | ||||||
| 			log.Fatal(output, err) | 			log.Fatal(output, err) | ||||||
| 		} | 		} | ||||||
| 		uid, err := strconv.Atoi(string(output[:len(output)-1])) | 		uid, err := strconv.Atoi(string(output[:len(output)-1])) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.") | 			logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0) | ||||||
| 			log.Fatal(err) | 			log.Fatal(err) | ||||||
| 		} | 		} | ||||||
| 		if uid != 0 { | 		if uid != 0 { | ||||||
|  | @ -74,7 +83,7 @@ func startControllers() { | ||||||
| 		if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { | 		if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { | ||||||
| 			err := servercfg.SetHost() | 			err := servercfg.SetHost() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Println("Unable to Set host. Exiting...") | 				logic.Log("Unable to Set host. Exiting...", 0) | ||||||
| 				log.Fatal(err) | 				log.Fatal(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | @ -90,7 +99,7 @@ func startControllers() { | ||||||
| 	if servercfg.IsDNSMode() { | 	if servercfg.IsDNSMode() { | ||||||
| 		err := dnslogic.SetDNS() | 		err := dnslogic.SetDNS() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Println("error occurred initializing DNS:", err) | 			logic.Log("error occurred initializing DNS: "+err.Error(), 0) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	//Run Rest Server | 	//Run Rest Server | ||||||
|  | @ -98,7 +107,7 @@ func startControllers() { | ||||||
| 		if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { | 		if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { | ||||||
| 			err := servercfg.SetHost() | 			err := servercfg.SetHost() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Println("Unable to Set host. Exiting...") | 				logic.Log("Unable to Set host. Exiting...", 0) | ||||||
| 				log.Fatal(err) | 				log.Fatal(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | @ -106,11 +115,11 @@ func startControllers() { | ||||||
| 		controller.HandleRESTRequests(&waitnetwork) | 		controller.HandleRESTRequests(&waitnetwork) | ||||||
| 	} | 	} | ||||||
| 	if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() { | 	if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() { | ||||||
| 		log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.") | 		logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	waitnetwork.Wait() | 	waitnetwork.Wait() | ||||||
| 	log.Println("[netmaker] exiting") | 	logic.Log("exiting", 0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func runClient(wg *sync.WaitGroup) { | func runClient(wg *sync.WaitGroup) { | ||||||
|  | @ -139,7 +148,7 @@ func runGRPC(wg *sync.WaitGroup) { | ||||||
| 	listener, err := net.Listen("tcp", ":"+grpcport) | 	listener, err := net.Listen("tcp", ":"+grpcport) | ||||||
| 	// Handle errors if any | 	// Handle errors if any | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err) | 		log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	s := grpc.NewServer( | 	s := grpc.NewServer( | ||||||
|  | @ -157,7 +166,7 @@ func runGRPC(wg *sync.WaitGroup) { | ||||||
| 			log.Fatalf("Failed to serve: %v", err) | 			log.Fatalf("Failed to serve: %v", err) | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 	log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)") | 	logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0) | ||||||
| 
 | 
 | ||||||
| 	// Right way to stop the server using a SHUTDOWN HOOK | 	// Right way to stop the server using a SHUTDOWN HOOK | ||||||
| 	// Create a channel to receive OS signals | 	// Create a channel to receive OS signals | ||||||
|  | @ -172,11 +181,11 @@ func runGRPC(wg *sync.WaitGroup) { | ||||||
| 	<-c | 	<-c | ||||||
| 
 | 
 | ||||||
| 	// After receiving CTRL+C Properly stop the server | 	// After receiving CTRL+C Properly stop the server | ||||||
| 	log.Println("Stopping the Agent server...") | 	logic.Log("Stopping the Agent server...", 0) | ||||||
| 	s.Stop() | 	s.Stop() | ||||||
| 	listener.Close() | 	listener.Close() | ||||||
| 	log.Println("Agent server closed..") | 	logic.Log("Agent server closed..", 0) | ||||||
| 	log.Println("Closed DB connection.") | 	logic.Log("Closed DB connection.", 0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func authServerUnaryInterceptor() grpc.ServerOption { | func authServerUnaryInterceptor() grpc.ServerOption { | ||||||
|  |  | ||||||
|  | @ -72,8 +72,20 @@ func GetServerConfig() config.ServerConfig { | ||||||
| 	cfg.AuthProvider = authInfo[0] | 	cfg.AuthProvider = authInfo[0] | ||||||
| 	cfg.ClientID = authInfo[1] | 	cfg.ClientID = authInfo[1] | ||||||
| 	cfg.ClientSecret = authInfo[2] | 	cfg.ClientSecret = authInfo[2] | ||||||
|  | 	cfg.FrontendURL = GetFrontendURL() | ||||||
|  | 
 | ||||||
| 	return cfg | 	return cfg | ||||||
| } | } | ||||||
|  | func GetFrontendURL() string { | ||||||
|  | 	var frontend = "" | ||||||
|  | 	if os.Getenv("FRONTEND_URL") != "" { | ||||||
|  | 		frontend = os.Getenv("FRONTEND_URL") | ||||||
|  | 	} else if config.Config.Server.FrontendURL != "" { | ||||||
|  | 		frontend = config.Config.Server.FrontendURL | ||||||
|  | 	} | ||||||
|  | 	return frontend | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func GetAPIConnString() string { | func GetAPIConnString() string { | ||||||
| 	conn := "" | 	conn := "" | ||||||
| 	if os.Getenv("SERVER_API_CONN_STRING") != "" { | 	if os.Getenv("SERVER_API_CONN_STRING") != "" { | ||||||
|  | @ -84,7 +96,7 @@ func GetAPIConnString() string { | ||||||
| 	return conn | 	return conn | ||||||
| } | } | ||||||
| func GetVersion() string { | func GetVersion() string { | ||||||
| 	version := "0.8.4" | 	version := "0.8.5" | ||||||
| 	if config.Config.Server.Version != "" { | 	if config.Config.Server.Version != "" { | ||||||
| 		version = config.Config.Server.Version | 		version = config.Config.Server.Version | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue