Merge pull request #386 from gravitl/feature_v0.8.5_oauth

Feature v0.8.5 oauth
This commit is contained in:
dcarns 2021-10-25 18:56:30 -04:00 committed by GitHub
commit a22c029a51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1046 additions and 352 deletions

View file

@ -8,7 +8,7 @@
<p align="center">
<a href="https://github.com/gravitl/netmaker/releases">
<img src="https://img.shields.io/badge/Version-0.8.4-informational?style=flat-square" />
<img src="https://img.shields.io/badge/Version-0.8.5-informational?style=flat-square" />
</a>
<a href="https://discord.gg/zRb9Vfhk8A">
<img src="https://img.shields.io/badge/community-discord-informational" />

157
auth/auth.go Normal file
View file

@ -0,0 +1,157 @@
package auth
import (
"encoding/base64"
"encoding/json"
"net/http"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
)
// == consts ==
const (
init_provider = "initprovider"
get_user_info = "getuserinfo"
handle_callback = "handlecallback"
handle_login = "handlelogin"
google_provider_name = "google"
azure_ad_provider_name = "azure-ad"
github_provider_name = "github"
verify_user = "verifyuser"
auth_key = "netmaker_auth"
)
var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
var auth_provider *oauth2.Config
func getCurrentAuthFunctions() map[string]interface{} {
var authInfo = servercfg.GetAuthProviderInfo()
var authProvider = authInfo[0]
switch authProvider {
case google_provider_name:
return google_functions
case azure_ad_provider_name:
return azure_ad_functions
case github_provider_name:
return github_functions
default:
return nil
}
}
// InitializeAuthProvider - initializes the auth provider if any is present
func InitializeAuthProvider() string {
var functions = getCurrentAuthFunctions()
if functions == nil {
return ""
}
var _, err = fetchPassValue(logic.RandomString(64))
if err != nil {
logic.Log(err.Error(), 0)
return ""
}
var currentFrontendURL = servercfg.GetFrontendURL()
if currentFrontendURL == "" {
return ""
}
var authInfo = servercfg.GetAuthProviderInfo()
functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2])
return authInfo[0]
}
// HandleAuthCallback - handles oauth callback
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
var functions = getCurrentAuthFunctions()
if functions == nil {
return
}
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
}
// HandleAuthLogin - handles oauth login
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
var functions = getCurrentAuthFunctions()
if functions == nil {
return
}
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
}
// IsOauthUser - returns
func IsOauthUser(user *models.User) error {
var currentValue, err = fetchPassValue("")
if err != nil {
return err
}
var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue))
return bCryptErr
}
// == private methods ==
func addUser(email string) error {
var hasAdmin, err = logic.HasAdmin()
if err != nil {
logic.Log("error checking for existence of admin user during OAuth login for "+email+", user not added", 1)
return err
} // generate random password to adapt to current model
var newPass, fetchErr = fetchPassValue("")
if fetchErr != nil {
return fetchErr
}
var newUser = models.User{
UserName: email,
Password: newPass,
}
if !hasAdmin { // must be first attempt, create an admin
if newUser, err = logic.CreateAdmin(newUser); err != nil {
logic.Log("error creating admin from user, "+email+", user not added", 1)
} else {
logic.Log("admin created from user, "+email+", was first user added", 0)
}
} else { // otherwise add to db as admin..?
// TODO: add ability to add users with preemptive permissions
newUser.IsAdmin = false
if newUser, err = logic.CreateUser(newUser); err != nil {
logic.Log("error creating user, "+email+", user not added", 1)
} else {
logic.Log("user created from, "+email+"", 0)
}
}
return nil
}
func fetchPassValue(newValue string) (string, error) {
type valueHolder struct {
Value string `json:"value" bson:"value"`
}
var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue))
var newValueHolder = &valueHolder{
Value: b64NewValue,
}
var data, marshalErr = json.Marshal(newValueHolder)
if marshalErr != nil {
return "", marshalErr
}
var currentValue, err = logic.FetchAuthSecret(auth_key, string(data))
if err != nil {
return "", err
}
var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder)
if unmarshErr != nil {
return "", unmarshErr
}
var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
if b64Err != nil {
logic.Log("could not decode pass", 0)
return "", nil
}
return string(b64CurrentValue), nil
}

126
auth/azure-ad.go Normal file
View file

@ -0,0 +1,126 @@
package auth
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
)
var azure_ad_functions = map[string]interface{}{
init_provider: initAzureAD,
get_user_info: getAzureUserInfo,
handle_callback: handleAzureCallback,
handle_login: handleAzureLogin,
verify_user: verifyAzureUser,
}
type azureOauthUser struct {
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle azure ad authentication here ==
func initAzureAD(redirectURL string, clientID string, clientSecret string) {
auth_provider = &oauth2.Config{
RedirectURL: redirectURL,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: []string{"User.Read"},
Endpoint: microsoft.AzureADEndpoint(os.Getenv("AZURE_TENANT")),
}
}
func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
oauth_state_string = logic.RandomString(16)
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return
} else if auth_provider == nil {
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
return
}
var url = auth_provider.AuthCodeURL(oauth_state_string)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}
func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getAzureUserInfo(r.FormValue("state"), r.FormValue("code"))
if err != nil {
logic.Log("error when getting user info from azure: "+err.Error(), 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return
}
_, err = logic.GetUser(content.UserPrincipalName)
if err != nil { // user must not exists, so try to make one
if err = addUser(content.UserPrincipalName); err != nil {
return
}
}
var newPass, fetchErr = fetchPassValue("")
if fetchErr != nil {
return
}
// send a netmaker jwt token
var authRequest = models.UserAuthParams{
UserName: content.UserPrincipalName,
Password: newPass,
}
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
if jwtErr != nil {
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
return
}
logic.Log("completed azure OAuth sigin in for "+content.UserPrincipalName, 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect)
}
func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
if state != oauth_state_string {
return nil, fmt.Errorf("invalid oauth state")
}
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
if err != nil {
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
}
var data []byte
data, err = json.Marshal(token)
if err != nil {
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
}
var httpReq, reqErr = http.NewRequest("GET", "https://graph.microsoft.com/v1.0/me", nil)
if reqErr != nil {
return nil, fmt.Errorf("failed to create request to GitHub")
}
httpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
response, err := http.DefaultClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
}
var userInfo = &azureOauthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
}
userInfo.AccessToken = string(data)
return userInfo, nil
}
func verifyAzureUser(token *oauth2.Token) bool {
return token.Valid()
}

129
auth/github.go Normal file
View file

@ -0,0 +1,129 @@
package auth
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
)
var github_functions = map[string]interface{}{
init_provider: initGithub,
get_user_info: getGithubUserInfo,
handle_callback: handleGithubCallback,
handle_login: handleGithubLogin,
verify_user: verifyGithubUser,
}
type githubOauthUser struct {
Login string `json:"login" bson:"login"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle github authentication here ==
func initGithub(redirectURL string, clientID string, clientSecret string) {
auth_provider = &oauth2.Config{
RedirectURL: redirectURL,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: []string{},
Endpoint: github.Endpoint,
}
}
func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
oauth_state_string = logic.RandomString(16)
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"?error=callback-error", http.StatusTemporaryRedirect)
return
} else if auth_provider == nil {
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
return
}
var url = auth_provider.AuthCodeURL(oauth_state_string)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}
func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getGithubUserInfo(r.URL.Query().Get("state"), r.URL.Query().Get("code"))
if err != nil {
logic.Log("error when getting user info from github: "+err.Error(), 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return
}
_, err = logic.GetUser(content.Login)
if err != nil { // user must not exist, so try to make one
if err = addUser(content.Login); err != nil {
return
}
}
var newPass, fetchErr = fetchPassValue("")
if fetchErr != nil {
return
}
// send a netmaker jwt token
var authRequest = models.UserAuthParams{
UserName: content.Login,
Password: newPass,
}
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
if jwtErr != nil {
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
return
}
logic.Log("completed github OAuth sigin in for "+content.Login, 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect)
}
func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
if state != oauth_state_string {
return nil, fmt.Errorf("invalid OAuth state")
}
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
if err != nil {
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
}
if !token.Valid() {
return nil, fmt.Errorf("GitHub code exchange yielded invalid token")
}
var data []byte
data, err = json.Marshal(token)
if err != nil {
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
}
var httpClient = &http.Client{}
var httpReq, reqErr = http.NewRequest("GET", "https://api.github.com/user", nil)
if reqErr != nil {
return nil, fmt.Errorf("failed to create request to GitHub")
}
httpReq.Header.Set("Authorization", "token "+token.AccessToken)
response, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
}
var userInfo = &githubOauthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
}
userInfo.AccessToken = string(data)
return userInfo, nil
}
func verifyGithubUser(token *oauth2.Token) bool {
return token.Valid()
}

120
auth/google.go Normal file
View file

@ -0,0 +1,120 @@
package auth
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
var google_functions = map[string]interface{}{
init_provider: initGoogle,
get_user_info: getGoogleUserInfo,
handle_callback: handleGoogleCallback,
handle_login: handleGoogleLogin,
verify_user: verifyGoogleUser,
}
type googleOauthUser struct {
Email string `json:"email" bson:"email"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
// == handle google authentication here ==
func initGoogle(redirectURL string, clientID string, clientSecret string) {
auth_provider = &oauth2.Config{
RedirectURL: redirectURL,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"},
Endpoint: google.Endpoint,
}
}
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
oauth_state_string = logic.RandomString(16)
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return
} else if auth_provider == nil {
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
return
}
var url = auth_provider.AuthCodeURL(oauth_state_string)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
}
func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getGoogleUserInfo(r.FormValue("state"), r.FormValue("code"))
if err != nil {
logic.Log("error when getting user info from google: "+err.Error(), 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return
}
_, err = logic.GetUser(content.Email)
if err != nil { // user must not exists, so try to make one
if err = addUser(content.Email); err != nil {
return
}
}
var newPass, fetchErr = fetchPassValue("")
if fetchErr != nil {
return
}
// send a netmaker jwt token
var authRequest = models.UserAuthParams{
UserName: content.Email,
Password: newPass,
}
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
if jwtErr != nil {
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
return
}
logic.Log("completed google OAuth sigin in for "+content.Email, 1)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
}
func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
if state != oauth_state_string {
return nil, fmt.Errorf("invalid OAuth state")
}
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
if err != nil {
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
}
var data []byte
data, err = json.Marshal(token)
if err != nil {
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
}
response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
if err != nil {
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
}
var userInfo = &googleOauthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
}
userInfo.AccessToken = string(data)
return userInfo, nil
}
func verifyGoogleUser(token *oauth2.Token) bool {
return token.Valid()
}

View file

@ -62,9 +62,12 @@ type ServerConfig struct {
DefaultNodeLimit int32 `yaml:"defaultnodelimit"`
Verbosity int32 `yaml:"verbosity"`
ServerCheckinInterval int64 `yaml:"servercheckininterval"`
AuthProvider string `yaml:"authprovider"`
ClientID string `yaml:"clientid"`
ClientSecret string `yaml:"clientsecret"`
FrontendURL string `yaml:"frontendurl"`
}
// Generic SQL Config
type SQLConfig struct {
Host string `yaml:"host"`

View file

@ -10,6 +10,7 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg"
)
@ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
log.Println(err)
}
}()
log.Println("REST Server successfully started on port " + port + " (REST)")
logic.Log("REST Server successfully started on port "+port+" (REST)", 0)
c := make(chan os.Signal)
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
@ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
<-c
// After receiving CTRL+C Properly stop the server
log.Println("Stopping the REST server...")
logic.Log("Stopping the REST server...", 0)
srv.Shutdown(context.TODO())
log.Println("REST Server closed.")
logic.Log("REST Server closed.", 0)
}

View file

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"time"
@ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1)
returnSuccessResponse(w, r, params["clientid"]+" deleted.")
}
// StringWithCharset - returns a random string in a charset
func StringWithCharset(length int, charset string) string {
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))

View file

@ -2,9 +2,9 @@ package controller
import (
"encoding/json"
"fmt"
"net/http"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
@ -48,7 +48,7 @@ func returnErrorResponse(response http.ResponseWriter, request *http.Request, er
if err != nil {
panic(err)
}
fmt.Println(errorMessage)
logic.Log("processed request error: "+errorMessage.Message, 1)
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(errorMessage.Code)
response.Write(jsonResponse)

View file

@ -7,12 +7,12 @@ import (
"net/http"
"strings"
"github.com/go-playground/validator/v10"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"golang.org/x/crypto/bcrypt"
)
func userHandlers(r *mux.Router) {
@ -21,11 +21,19 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST")
r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST")
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(updateUser))).Methods("PUT")
r.HandleFunc("/api/users/networks/{username}", authorizeUserAdm(http.HandlerFunc(updateUserNetworks))).Methods("PUT")
r.HandleFunc("/api/users/{username}/adm", authorizeUserAdm(http.HandlerFunc(updateUserAdm))).Methods("PUT")
r.HandleFunc("/api/users/{username}", authorizeUserAdm(http.HandlerFunc(createUser))).Methods("POST")
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE")
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET")
r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET")
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET")
}
func throwOauthError(response http.ResponseWriter, request *http.Request) {
returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized"))
}
// Node authenticates using its password and retrieves a JWT for authorization.
@ -46,7 +54,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
return
}
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
if err != nil {
returnErrorResponse(response, request, formatError(err, "badrequest"))
return
@ -79,35 +87,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
response.Write(successJSONResponse)
}
// VerifyAuthRequest - verifies an auth request
func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
var result models.User
if authRequest.UserName == "" {
return "", errors.New("username can't be empty")
} else if authRequest.Password == "" {
return "", errors.New("password can't be empty")
}
//Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
if err != nil {
return "", errors.New("incorrect credentials")
}
if err = json.Unmarshal([]byte(record), &result); err != nil {
return "", errors.New("incorrect credentials")
}
// compare password from request to stored password in database
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
// TODO: Consider a way of hashing the password client side before sending, or using certificates
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
return "", errors.New("incorrect credentials")
}
//Create a new JWT for the node
tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin)
return tokenString, nil
}
// The middleware for most requests to the API
// They all pass through here first
// This will validate the JWT (or check for master token)
@ -181,37 +160,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error {
return nil
}
// HasAdmin - checks if server has an admin
func HasAdmin() (bool, error) {
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
} else {
return true, err
}
}
for _, value := range collection { // filter for isadmin true
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
if user.IsAdmin {
return true, nil
}
}
return false, err
}
func hasAdmin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
hasadmin, err := HasAdmin()
hasadmin, err := logic.HasAdmin()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
@ -221,20 +174,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
}
// GetUser - gets a user
func GetUser(username string) (models.ReturnUser, error) {
var user models.ReturnUser
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
return user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return models.ReturnUser{}, err
}
return user, err
}
// GetUserInternal - gets an internal user
func GetUserInternal(username string) (models.User, error) {
@ -249,30 +188,6 @@ func GetUserInternal(username string) (models.User, error) {
return user, err
}
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
var users []models.ReturnUser
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
for _, value := range collection {
var user models.ReturnUser
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
users = append(users, user)
}
return users, err
}
// Get an individual node. Nothin fancy here folks.
func getUser(w http.ResponseWriter, r *http.Request) {
// set header.
@ -280,7 +195,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
usernameFetched := params["username"]
user, err := GetUser(usernameFetched)
user, err := logic.GetUser(usernameFetched)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
@ -295,7 +210,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
// set header.
w.Header().Set("Content-Type", "application/json")
users, err := GetUsers()
users, err := logic.GetUsers()
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
@ -306,42 +221,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(users)
}
// CreateUser - creates a user
func CreateUser(user models.User) (models.User, error) {
// check if user exists
if _, err := GetUser(user.UserName); err == nil {
return models.User{}, errors.New("user exists")
}
err := ValidateUser("create", user)
if err != nil {
return models.User{}, err
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
if err != nil {
return user, err
}
// set password to encrypted password
user.Password = string(hash)
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
if tokenString == "" {
// returnErrorResponse(w, r, errorResponse)
return user, err
}
// connect db
data, err := json.Marshal(&user)
if err != nil {
return user, err
}
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
return user, err
}
func createAdmin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -349,7 +228,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
// get node from body of request
_ = json.NewDecoder(r.Body).Decode(&admin)
admin, err := CreateAdmin(admin)
admin, err := logic.CreateAdmin(admin)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
@ -359,18 +238,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(admin)
}
func CreateAdmin(admin models.User) (models.User, error) {
hasadmin, err := HasAdmin()
if err != nil {
return models.User{}, err
}
if hasadmin {
return models.User{}, errors.New("admin user already exists")
}
admin.IsAdmin = true
return CreateUser(admin)
}
func createUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -378,7 +245,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
// get node from body of request
_ = json.NewDecoder(r.Body).Decode(&user)
user, err := CreateUser(user)
user, err := logic.CreateUser(user)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
@ -388,53 +255,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(user)
}
// UpdateUser - updates a given user
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
//check if user exists
if _, err := GetUser(user.UserName); err != nil {
return models.User{}, err
}
err := ValidateUser("update", userchange)
if err != nil {
return models.User{}, err
}
queryUser := user.UserName
if userchange.UserName != "" {
user.UserName = userchange.UserName
}
if len(userchange.Networks) > 0 {
user.Networks = userchange.Networks
}
if userchange.Password != "" {
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
if err != nil {
return userchange, err
}
// set password to encrypted password
userchange.Password = string(hash)
user.Password = userchange.Password
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return models.User{}, err
}
data, err := json.Marshal(&user)
if err != nil {
return models.User{}, err
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return models.User{}, err
}
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
return user, nil
}
func updateUser(w http.ResponseWriter, r *http.Request) {
func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
var user models.User
@ -452,8 +273,40 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
err = logic.UpdateUserNetworks(userchange.Networks, userchange.IsAdmin, &user)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
}
functions.PrintUserLog(username, "status was updated", 1)
json.NewEncoder(w).Encode(user)
}
func updateUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
var user models.User
// start here
username := params["username"]
user, err := GetUserInternal(username)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
if auth.IsOauthUser(&user) == nil {
returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user %s", username), "forbidden"))
return
}
var userchange models.User
// we decode our body request params
err = json.NewDecoder(r.Body).Decode(&userchange)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
userchange.Networks = nil
user, err = UpdateUser(userchange, user)
user, err = logic.UpdateUser(userchange, user)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
@ -473,6 +326,10 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
if auth.IsOauthUser(&user) != nil {
returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user"), "forbidden"))
return
}
var userchange models.User
// we decode our body request params
err = json.NewDecoder(r.Body).Decode(&userchange)
@ -480,7 +337,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal"))
return
}
user, err = UpdateUser(userchange, user)
user, err = logic.UpdateUser(userchange, user)
if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest"))
return
@ -489,20 +346,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(user)
}
// DeleteUser - deletes a given user
func DeleteUser(user string) (bool, error) {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return false, errors.New("user does not exist")
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
if err != nil {
return false, err
}
return true, nil
}
func deleteUser(w http.ResponseWriter, r *http.Request) {
// Set header
w.Header().Set("Content-Type", "application/json")
@ -511,7 +354,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
username := params["username"]
success, err := DeleteUser(username)
success, err := logic.DeleteUser(username)
if err != nil {
returnErrorResponse(w, r, formatError(err, "internal"))
@ -524,17 +367,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
functions.PrintUserLog(username, "was deleted", 1)
json.NewEncoder(w).Encode(params["username"] + " deleted.")
}
// ValidateUser - validates a user model
func ValidateUser(operation string, user models.User) error {
v := validator.New()
err := v.Struct(user)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
fmt.Println(e)
}
}
return err
}

View file

@ -4,52 +4,53 @@ import (
"testing"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/stretchr/testify/assert"
)
func deleteAllUsers() {
users, _ := GetUsers()
users, _ := logic.GetUsers()
for _, user := range users {
DeleteUser(user.UserName)
logic.DeleteUser(user.UserName)
}
}
func TestHasAdmin(t *testing.T) {
//delete all current users
database.InitializeDatabase()
users, _ := GetUsers()
users, _ := logic.GetUsers()
for _, user := range users {
success, err := DeleteUser(user.UserName)
success, err := logic.DeleteUser(user.UserName)
assert.Nil(t, err)
assert.True(t, success)
}
t.Run("NoUser", func(t *testing.T) {
found, err := HasAdmin()
found, err := logic.HasAdmin()
assert.Nil(t, err)
assert.False(t, found)
})
t.Run("No admin user", func(t *testing.T) {
var user = models.User{"noadmin", "password", nil, false}
_, err := CreateUser(user)
_, err := logic.CreateUser(user)
assert.Nil(t, err)
found, err := HasAdmin()
found, err := logic.HasAdmin()
assert.Nil(t, err)
assert.False(t, found)
})
t.Run("admin user", func(t *testing.T) {
var user = models.User{"admin", "password", nil, true}
_, err := CreateUser(user)
_, err := logic.CreateUser(user)
assert.Nil(t, err)
found, err := HasAdmin()
found, err := logic.HasAdmin()
assert.Nil(t, err)
assert.True(t, found)
})
t.Run("multiple admins", func(t *testing.T) {
var user = models.User{"admin1", "password", nil, true}
_, err := CreateUser(user)
_, err := logic.CreateUser(user)
assert.Nil(t, err)
found, err := HasAdmin()
found, err := logic.HasAdmin()
assert.Nil(t, err)
assert.True(t, found)
})
@ -60,12 +61,12 @@ func TestCreateUser(t *testing.T) {
deleteAllUsers()
user := models.User{"admin", "password", nil, true}
t.Run("NoUser", func(t *testing.T) {
admin, err := CreateUser(user)
admin, err := logic.CreateUser(user)
assert.Nil(t, err)
assert.Equal(t, user.UserName, admin.UserName)
})
t.Run("UserExists", func(t *testing.T) {
_, err := CreateUser(user)
_, err := logic.CreateUser(user)
assert.NotNil(t, err)
assert.EqualError(t, err, "user exists")
})
@ -78,14 +79,14 @@ func TestCreateAdmin(t *testing.T) {
t.Run("NoAdmin", func(t *testing.T) {
user.UserName = "admin"
user.Password = "password"
admin, err := CreateAdmin(user)
admin, err := logic.CreateAdmin(user)
assert.Nil(t, err)
assert.Equal(t, user.UserName, admin.UserName)
})
t.Run("AdminExists", func(t *testing.T) {
user.UserName = "admin2"
user.Password = "password1"
admin, err := CreateAdmin(user)
admin, err := logic.CreateAdmin(user)
assert.EqualError(t, err, "admin user already exists")
assert.Equal(t, admin, models.User{})
})
@ -95,14 +96,14 @@ func TestDeleteUser(t *testing.T) {
database.InitializeDatabase()
deleteAllUsers()
t.Run("NonExistent User", func(t *testing.T) {
deleted, err := DeleteUser("admin")
deleted, err := logic.DeleteUser("admin")
assert.EqualError(t, err, "user does not exist")
assert.False(t, deleted)
})
t.Run("Existing User", func(t *testing.T) {
user := models.User{"admin", "password", nil, true}
CreateUser(user)
deleted, err := DeleteUser("admin")
logic.CreateUser(user)
deleted, err := logic.DeleteUser("admin")
assert.Nil(t, err)
assert.True(t, deleted)
})
@ -114,44 +115,44 @@ func TestValidateUser(t *testing.T) {
t.Run("Valid Create", func(t *testing.T) {
user.UserName = "admin"
user.Password = "validpass"
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.Nil(t, err)
})
t.Run("Valid Update", func(t *testing.T) {
user.UserName = "admin"
user.Password = "password"
err := ValidateUser("update", user)
err := logic.ValidateUser(user)
assert.Nil(t, err)
})
t.Run("Invalid UserName", func(t *testing.T) {
t.Skip()
user.UserName = "*invalid"
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.Error(t, err)
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
})
t.Run("Short UserName", func(t *testing.T) {
t.Skip()
user.UserName = "1"
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.NotNil(t, err)
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
})
t.Run("Empty UserName", func(t *testing.T) {
t.Skip()
user.UserName = ""
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.EqualError(t, err, "some string")
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
})
t.Run("EmptyPassword", func(t *testing.T) {
user.Password = ""
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'required' tag")
})
t.Run("ShortPassword", func(t *testing.T) {
user.Password = "123"
err := ValidateUser("create", user)
err := logic.ValidateUser(user)
assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'min' tag")
})
}
@ -160,14 +161,14 @@ func TestGetUser(t *testing.T) {
database.InitializeDatabase()
deleteAllUsers()
t.Run("NonExistantUser", func(t *testing.T) {
admin, err := GetUser("admin")
admin, err := logic.GetUser("admin")
assert.EqualError(t, err, "could not find any records")
assert.Equal(t, "", admin.UserName)
})
t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true}
CreateUser(user)
admin, err := GetUser("admin")
logic.CreateUser(user)
admin, err := logic.GetUser("admin")
assert.Nil(t, err)
assert.Equal(t, user.UserName, admin.UserName)
})
@ -183,7 +184,7 @@ func TestGetUserInternal(t *testing.T) {
})
t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true}
CreateUser(user)
logic.CreateUser(user)
admin, err := GetUserInternal("admin")
assert.Nil(t, err)
assert.Equal(t, user.UserName, admin.UserName)
@ -194,21 +195,21 @@ func TestGetUsers(t *testing.T) {
database.InitializeDatabase()
deleteAllUsers()
t.Run("NonExistantUser", func(t *testing.T) {
admin, err := GetUsers()
admin, err := logic.GetUsers()
assert.EqualError(t, err, "could not find any records")
assert.Equal(t, []models.ReturnUser(nil), admin)
})
t.Run("UserExisits", func(t *testing.T) {
user := models.User{"admin", "password", nil, true}
CreateUser(user)
admins, err := GetUsers()
logic.CreateUser(user)
admins, err := logic.GetUsers()
assert.Nil(t, err)
assert.Equal(t, user.UserName, admins[0].UserName)
})
t.Run("MulipleUsers", func(t *testing.T) {
user := models.User{"user", "password", nil, true}
CreateUser(user)
admins, err := GetUsers()
logic.CreateUser(user)
admins, err := logic.GetUsers()
assert.Nil(t, err)
for _, u := range admins {
if u.UserName == "admin" {
@ -227,14 +228,14 @@ func TestUpdateUser(t *testing.T) {
user := models.User{"admin", "password", nil, true}
newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true}
t.Run("NonExistantUser", func(t *testing.T) {
admin, err := UpdateUser(newuser, user)
admin, err := logic.UpdateUser(newuser, user)
assert.EqualError(t, err, "could not find any records")
assert.Equal(t, "", admin.UserName)
})
t.Run("UserExists", func(t *testing.T) {
CreateUser(user)
admin, err := UpdateUser(newuser, user)
logic.CreateUser(user)
admin, err := logic.UpdateUser(newuser, user)
assert.Nil(t, err)
assert.Equal(t, newuser.UserName, admin.UserName)
})
@ -271,43 +272,43 @@ func TestVerifyAuthRequest(t *testing.T) {
t.Run("EmptyUserName", func(t *testing.T) {
authRequest.UserName = ""
authRequest.Password = "Password"
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Equal(t, "", jwt)
assert.EqualError(t, err, "username can't be empty")
})
t.Run("EmptyPassword", func(t *testing.T) {
authRequest.UserName = "admin"
authRequest.Password = ""
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Equal(t, "", jwt)
assert.EqualError(t, err, "password can't be empty")
})
t.Run("NonExistantUser", func(t *testing.T) {
authRequest.UserName = "admin"
authRequest.Password = "password"
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Equal(t, "", jwt)
assert.EqualError(t, err, "incorrect credentials")
})
t.Run("Non-Admin", func(t *testing.T) {
user := models.User{"nonadmin", "somepass", nil, false}
CreateUser(user)
logic.CreateUser(user)
authRequest := models.UserAuthParams{"nonadmin", "somepass"}
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.NotNil(t, jwt)
assert.Nil(t, err)
})
t.Run("WrongPassword", func(t *testing.T) {
user := models.User{"admin", "password", nil, false}
CreateUser(user)
logic.CreateUser(user)
authRequest := models.UserAuthParams{"admin", "badpass"}
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Equal(t, "", jwt)
assert.EqualError(t, err, "incorrect credentials")
})
t.Run("Success", func(t *testing.T) {
authRequest := models.UserAuthParams{"admin", "password"}
jwt, err := VerifyAuthRequest(authRequest)
jwt, err := logic.VerifyAuthRequest(authRequest)
assert.Nil(t, err)
assert.NotNil(t, jwt)
})

View file

@ -39,6 +39,9 @@ const SERVERCONF_TABLE_NAME = "serverconf"
// DATABASE_FILENAME - database file name
const DATABASE_FILENAME = "netmaker.db"
// GENERATED_TABLE_NAME - stores server generated k/v
const GENERATED_TABLE_NAME = "generated"
// == ERROR CONSTS ==
// NO_RECORD - no singular result found
@ -87,11 +90,11 @@ func getCurrentDB() map[string]interface{} {
}
func InitializeDatabase() error {
log.Println("connecting to", servercfg.GetDB())
log.Println("[netmaker] connecting to", servercfg.GetDB())
tperiod := time.Now().Add(10 * time.Second)
for {
if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil {
log.Println("unable to connect to db, retrying . . .")
log.Println("[netmaker] unable to connect to db, retrying . . .")
if time.Now().After(tperiod) {
return err
}
@ -114,6 +117,7 @@ func createTables() {
createTable(INT_CLIENTS_TABLE_NAME)
createTable(PEERS_TABLE_NAME)
createTable(SERVERCONF_TABLE_NAME)
createTable(GENERATED_TABLE_NAME)
}
func createTable(tableName string) error {

1
go.mod
View file

@ -17,6 +17,7 @@ require (
github.com/urfave/cli/v2 v2.3.0
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 // indirect
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect
golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e // indirect
golang.org/x/text v0.3.7-0.20210524175448-3115f89c4b99 // indirect
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 // indirect

268
logic/auth.go Normal file
View file

@ -0,0 +1,268 @@
package logic
import (
"encoding/json"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/models"
"golang.org/x/crypto/bcrypt"
)
// HasAdmin - checks if server has an admin
func HasAdmin() (bool, error) {
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
} else {
return true, err
}
}
for _, value := range collection { // filter for isadmin true
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
if user.IsAdmin {
return true, nil
}
}
return false, err
}
// GetUser - gets a user
func GetUser(username string) (models.ReturnUser, error) {
var user models.ReturnUser
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
return user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return models.ReturnUser{}, err
}
return user, err
}
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
var users []models.ReturnUser
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
for _, value := range collection {
var user models.ReturnUser
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
users = append(users, user)
}
return users, err
}
// CreateUser - creates a user
func CreateUser(user models.User) (models.User, error) {
// check if user exists
if _, err := GetUser(user.UserName); err == nil {
return models.User{}, errors.New("user exists")
}
var err = ValidateUser(user)
if err != nil {
return models.User{}, err
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
if err != nil {
return user, err
}
// set password to encrypted password
user.Password = string(hash)
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
if tokenString == "" {
// returnErrorResponse(w, r, errorResponse)
return user, err
}
// connect db
data, err := json.Marshal(&user)
if err != nil {
return user, err
}
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
return user, err
}
// CreateAdmin - creates an admin user
func CreateAdmin(admin models.User) (models.User, error) {
hasadmin, err := HasAdmin()
if err != nil {
return models.User{}, err
}
if hasadmin {
return models.User{}, errors.New("admin user already exists")
}
admin.IsAdmin = true
return CreateUser(admin)
}
// VerifyAuthRequest - verifies an auth request
func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
var result models.User
if authRequest.UserName == "" {
return "", errors.New("username can't be empty")
} else if authRequest.Password == "" {
return "", errors.New("password can't be empty")
}
//Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
if err != nil {
return "", errors.New("incorrect credentials")
}
if err = json.Unmarshal([]byte(record), &result); err != nil {
return "", errors.New("incorrect credentials")
}
// compare password from request to stored password in database
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
// TODO: Consider a way of hashing the password client side before sending, or using certificates
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
return "", errors.New("incorrect credentials")
}
//Create a new JWT for the node
tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin)
return tokenString, nil
}
// UpdateUserNetworks - updates the networks of a given user
func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.User) error {
// check if user exists
if returnedUser, err := GetUser(currentUser.UserName); err != nil {
return err
} else if returnedUser.IsAdmin {
return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName)
}
if isadmin {
currentUser.IsAdmin = true
currentUser.Networks = nil
} else {
currentUser.Networks = newNetworks
}
data, err := json.Marshal(currentUser)
if err != nil {
return err
}
if err = database.Insert(currentUser.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return err
}
return nil
}
// UpdateUser - updates a given user
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
//check if user exists
if _, err := GetUser(user.UserName); err != nil {
return models.User{}, err
}
err := ValidateUser(userchange)
if err != nil {
return models.User{}, err
}
queryUser := user.UserName
if userchange.UserName != "" {
user.UserName = userchange.UserName
}
if len(userchange.Networks) > 0 {
user.Networks = userchange.Networks
}
if userchange.Password != "" {
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
if err != nil {
return userchange, err
}
// set password to encrypted password
userchange.Password = string(hash)
user.Password = userchange.Password
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return models.User{}, err
}
data, err := json.Marshal(&user)
if err != nil {
return models.User{}, err
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return models.User{}, err
}
Log("updated user "+queryUser, 1)
return user, nil
}
// ValidateUser - validates a user model
func ValidateUser(user models.User) error {
v := validator.New()
err := v.Struct(user)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
Log(e.Error(), 2)
}
}
return err
}
// DeleteUser - deletes a given user
func DeleteUser(user string) (bool, error) {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return false, errors.New("user does not exist")
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
if err != nil {
return false, err
}
return true, nil
}
// FetchAuthSecret - manages secrets for oauth
func FetchAuthSecret(key string, secret string) (string, error) {
var record, err = database.FetchRecord(database.GENERATED_TABLE_NAME, key)
if err != nil {
if err = database.Insert(key, secret, database.GENERATED_TABLE_NAME); err != nil {
return "", err
} else {
return secret, nil
}
}
return record, nil
}

View file

@ -5,10 +5,17 @@ import (
"os/exec"
"strings"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
)
// CheckNetworkExists - checks i a network exists for this netmaker instance
func CheckNetworkExists(network string) bool {
var _, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, network)
return err == nil
}
// GetLocalIP - gets the local ip
func GetLocalIP(node models.Node) string {

View file

@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"log"
"math/rand"
"strconv"
"strings"
"time"
@ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin
return peers, err
}
// RandomString - returns a random string in a charset
func RandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
func setPeerInfo(node models.Node) models.Node {
var peer models.Node
peer.RelayAddrs = node.RelayAddrs
@ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node {
func Log(message string, loglevel int) {
log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile))
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 {
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 {
log.Println("[netmaker] " + message)
}
}

View file

@ -62,10 +62,11 @@ func setWGConfig(node models.Node, network string, peerupdate bool) error {
var iface string
iface = node.Interface
err = setServerPeers(iface, node.PersistentKeepalive, peers)
Log("updated peers on server "+node.Name, 2)
} else {
err = initWireguard(&node, privkey, peers, hasGateway, gateways)
Log("finished setting wg config on server "+node.Name, 3)
}
Log("finished setting wg config on server "+node.Name, 1)
return err
}

39
main.go
View file

@ -13,11 +13,13 @@ import (
"sync"
"time"
"github.com/gravitl/netmaker/auth"
controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/dnslogic"
"github.com/gravitl/netmaker/functions"
nodepb "github.com/gravitl/netmaker/grpc"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg"
@ -35,20 +37,29 @@ func main() {
func initialize() { // Client Mode Prereq Check
var err error
if err = database.InitializeDatabase(); err != nil {
log.Println("Error connecting to database.")
logic.Log("Error connecting to database", 0)
log.Fatal(err)
}
log.Println("database successfully connected.")
logic.Log("database successfully connected", 0)
var authProvider = auth.InitializeAuthProvider()
if authProvider != "" {
logic.Log("OAuth provider, "+authProvider+", initialized", 0)
} else {
logic.Log("no OAuth provider found or not configured, continuing without OAuth", 0)
}
if servercfg.IsClientMode() != "off" {
output, err := ncutils.RunCmd("id -u", true)
if err != nil {
log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.")
logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0)
log.Fatal(output, err)
}
uid, err := strconv.Atoi(string(output[:len(output)-1]))
if err != nil {
log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.")
logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0)
log.Fatal(err)
}
if uid != 0 {
@ -74,7 +85,7 @@ func startControllers() {
if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" {
err := servercfg.SetHost()
if err != nil {
log.Println("Unable to Set host. Exiting...")
logic.Log("Unable to Set host. Exiting...", 0)
log.Fatal(err)
}
}
@ -90,7 +101,7 @@ func startControllers() {
if servercfg.IsDNSMode() {
err := dnslogic.SetDNS()
if err != nil {
log.Println("error occurred initializing DNS:", err)
logic.Log("error occurred initializing DNS: "+err.Error(), 0)
}
}
//Run Rest Server
@ -98,7 +109,7 @@ func startControllers() {
if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" {
err := servercfg.SetHost()
if err != nil {
log.Println("Unable to Set host. Exiting...")
logic.Log("Unable to Set host. Exiting...", 0)
log.Fatal(err)
}
}
@ -106,11 +117,11 @@ func startControllers() {
controller.HandleRESTRequests(&waitnetwork)
}
if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() {
log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.")
logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0)
}
waitnetwork.Wait()
log.Println("[netmaker] exiting")
logic.Log("exiting", 0)
}
func runClient(wg *sync.WaitGroup) {
@ -139,7 +150,7 @@ func runGRPC(wg *sync.WaitGroup) {
listener, err := net.Listen("tcp", ":"+grpcport)
// Handle errors if any
if err != nil {
log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err)
log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err)
}
s := grpc.NewServer(
@ -157,7 +168,7 @@ func runGRPC(wg *sync.WaitGroup) {
log.Fatalf("Failed to serve: %v", err)
}
}()
log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)")
logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0)
// Right way to stop the server using a SHUTDOWN HOOK
// Create a channel to receive OS signals
@ -172,11 +183,11 @@ func runGRPC(wg *sync.WaitGroup) {
<-c
// After receiving CTRL+C Properly stop the server
log.Println("Stopping the Agent server...")
logic.Log("Stopping the Agent server...", 0)
s.Stop()
listener.Close()
log.Println("Agent server closed..")
log.Println("Closed DB connection.")
logic.Log("Agent server closed..", 0)
logic.Log("Closed DB connection.", 0)
}
func authServerUnaryInterceptor() grpc.ServerOption {

View file

@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"os/exec"
nodepb "github.com/gravitl/netmaker/grpc"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/auth"
@ -18,6 +16,8 @@ import (
"github.com/gravitl/netmaker/netclient/wireguard"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"log"
"os/exec"
)
// JoinNetwork - helps a client join a network
@ -84,8 +84,8 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error {
if ncutils.IsLinux() {
_, err := exec.LookPath("resolvectl")
if err != nil {
ncutils.PrintLog("resolvectl not present",2)
ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management",2)
ncutils.PrintLog("resolvectl not present", 2)
ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management", 2)
cfg.Node.DNSOn = "no"
}
}

View file

@ -7,6 +7,7 @@ import (
"net/http"
"os"
"strconv"
"strings"
"github.com/gravitl/netmaker/config"
)
@ -65,8 +66,26 @@ func GetServerConfig() config.ServerConfig {
cfg.Database = GetDB()
cfg.Platform = GetPlatform()
cfg.Version = GetVersion()
// == auth config ==
var authInfo = GetAuthProviderInfo()
cfg.AuthProvider = authInfo[0]
cfg.ClientID = authInfo[1]
cfg.ClientSecret = authInfo[2]
cfg.FrontendURL = GetFrontendURL()
return cfg
}
func GetFrontendURL() string {
var frontend = ""
if os.Getenv("FRONTEND_URL") != "" {
frontend = os.Getenv("FRONTEND_URL")
} else if config.Config.Server.FrontendURL != "" {
frontend = config.Config.Server.FrontendURL
}
return frontend
}
func GetAPIConnString() string {
conn := ""
if os.Getenv("SERVER_API_CONN_STRING") != "" {
@ -77,7 +96,7 @@ func GetAPIConnString() string {
return conn
}
func GetVersion() string {
version := "0.8.4"
version := "0.8.5"
if config.Config.Server.Version != "" {
version = config.Config.Server.Version
}
@ -398,6 +417,25 @@ func GetServerCheckinInterval() int64 {
return t
}
// GetAuthProviderInfo = gets the oauth provider info
func GetAuthProviderInfo() []string {
var authProvider = ""
if os.Getenv("AUTH_PROVIDER") != "" && os.Getenv("CLIENT_ID") != "" && os.Getenv("CLIENT_SECRET") != "" {
authProvider = strings.ToLower(os.Getenv("AUTH_PROVIDER"))
if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
return []string{authProvider, os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET")}
} else {
authProvider = ""
}
} else if config.Config.Server.AuthProvider != "" && config.Config.Server.ClientID != "" && config.Config.Server.ClientSecret != "" {
authProvider = strings.ToLower(config.Config.Server.AuthProvider)
if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
return []string{authProvider, config.Config.Server.ClientID, config.Config.Server.ClientSecret}
}
}
return []string{"", "", ""}
}
// GetMacAddr - get's mac address
func getMacAddr() string {
ifas, err := net.Interfaces()

View file

@ -1,8 +1,8 @@
package servercfg
import (
"os"
"github.com/gravitl/netmaker/config"
"os"
"strconv"
)