began oauth implementation

This commit is contained in:
0xdcarns 2021-10-21 15:28:58 -04:00
parent 8a54f50676
commit 4e4e8b3ab5
11 changed files with 330 additions and 243 deletions

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
@ -13,14 +14,20 @@ const (
get_user_info = "getuserinfo" get_user_info = "getuserinfo"
handle_callback = "handlecallback" handle_callback = "handlecallback"
handle_login = "handlelogin" handle_login = "handlelogin"
oauth_state_string = "netmaker-oauth-state"
google_provider_name = "google" google_provider_name = "google"
azure_ad_provider_name = "azure-ad" azure_ad_provider_name = "azure-ad"
github_provider_name = "github" github_provider_name = "github"
verify_user = "verifyuser"
) )
var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
var auth_provider *oauth2.Config var auth_provider *oauth2.Config
type OauthUser struct {
Email string `json:"email" bson:"email"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
func getCurrentAuthFunctions() map[string]interface{} { func getCurrentAuthFunctions() map[string]interface{} {
var authInfo = servercfg.GetAuthProviderInfo() var authInfo = servercfg.GetAuthProviderInfo()
var authProvider = authInfo[0] var authProvider = authInfo[0]
@ -37,14 +44,14 @@ func getCurrentAuthFunctions() map[string]interface{} {
} }
// InitializeAuthProvider - initializes the auth provider if any is present // InitializeAuthProvider - initializes the auth provider if any is present
func InitializeAuthProvider() bool { func InitializeAuthProvider() string {
var functions = getCurrentAuthFunctions() var functions = getCurrentAuthFunctions()
if functions == nil { if functions == nil {
return false return ""
} }
var authInfo = servercfg.GetAuthProviderInfo() var authInfo = servercfg.GetAuthProviderInfo()
functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString(), authInfo[1], authInfo[2]) functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2])
return auth_provider != nil return authInfo[0]
} }
// HandleAuthCallback - handles oauth callback // HandleAuthCallback - handles oauth callback
@ -64,3 +71,17 @@ func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
} }
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r) functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
} }
// VerifyUserToken - checks if oauth2 token is valid
func VerifyUserToken(accessToken string) bool {
var token = &oauth2.Token{}
var err = json.Unmarshal([]byte(accessToken), token)
if err != nil || !token.Valid() {
return false
}
var functions = getCurrentAuthFunctions()
if functions == nil {
return false
}
return functions[verify_user].(func(*oauth2.Token) bool)(token)
}

View file

@ -1,10 +1,13 @@
package auth package auth
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
@ -14,6 +17,7 @@ var google_functions = map[string]interface{}{
get_user_info: getUserInfo, get_user_info: getUserInfo,
handle_callback: handleGoogleCallback, handle_callback: handleGoogleCallback,
handle_login: handleGoogleLogin, handle_login: handleGoogleLogin,
verify_user: verifyGoogleUser,
} }
// == handle google authentication here == // == handle google authentication here ==
@ -29,6 +33,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) {
} }
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
oauth_state_string = logic.RandomString(16)
url := auth_provider.AuthCodeURL(oauth_state_string) url := auth_provider.AuthCodeURL(oauth_state_string)
http.Redirect(w, r, url, http.StatusTemporaryRedirect) http.Redirect(w, r, url, http.StatusTemporaryRedirect)
} }
@ -38,20 +43,26 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code")) var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code"))
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
http.Redirect(w, r, "/api/oauth/error", http.StatusTemporaryRedirect) http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
return return
} }
fmt.Fprintf(w, "Content: %s\n", content) logic.Log("completed google oauth sigin in for "+content.Email, 0)
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth="+content.AccessToken+"&email="+content.Email, http.StatusPermanentRedirect)
} }
func getUserInfo(state string, code string) ([]byte, error) { func getUserInfo(state string, code string) (*OauthUser, error) {
if state != oauth_state_string { if state != oauth_state_string {
return nil, fmt.Errorf("invalid oauth state") return nil, fmt.Errorf("invalid oauth state")
} }
token, err := auth_provider.Exchange(oauth2.NoContext, code) var token, err = auth_provider.Exchange(oauth2.NoContext, code)
if err != nil { if err != nil {
return nil, fmt.Errorf("code exchange failed: %s", err.Error()) return nil, fmt.Errorf("code exchange failed: %s", err.Error())
} }
var data []byte
data, err = json.Marshal(token)
if err != nil {
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
}
response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed getting user info: %s", err.Error()) return nil, fmt.Errorf("failed getting user info: %s", err.Error())
@ -61,5 +72,19 @@ func getUserInfo(state string, code string) ([]byte, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed reading response body: %s", err.Error()) return nil, fmt.Errorf("failed reading response body: %s", err.Error())
} }
return contents, nil var userInfo = &OauthUser{}
if err = json.Unmarshal(contents, userInfo); err != nil {
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
}
userInfo.AccessToken = string(data)
return userInfo, nil
}
func verifyGoogleUser(token *oauth2.Token) bool {
if token.Valid() {
var err error
_, err = http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
return err == nil
}
return false
} }

View file

@ -65,6 +65,7 @@ type ServerConfig struct {
AuthProvider string `yaml:"authprovider"` AuthProvider string `yaml:"authprovider"`
ClientID string `yaml:"clientid"` ClientID string `yaml:"clientid"`
ClientSecret string `yaml:"clientsecret"` ClientSecret string `yaml:"clientsecret"`
FrontendURL string `yaml:"frontendurl"`
} }
// Generic SQL Config // Generic SQL Config

View file

@ -10,6 +10,7 @@ import (
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
) )
@ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
log.Println(err) log.Println(err)
} }
}() }()
logic.Log("REST Server successfully started on port "+port+" (REST)", 0)
log.Println("REST Server successfully started on port " + port + " (REST)")
c := make(chan os.Signal) c := make(chan os.Signal)
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C) // Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
@ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
<-c <-c
// After receiving CTRL+C Properly stop the server // After receiving CTRL+C Properly stop the server
log.Println("Stopping the REST server...") logic.Log("Stopping the REST server...", 0)
srv.Shutdown(context.TODO()) srv.Shutdown(context.TODO())
log.Println("REST Server closed.") logic.Log("REST Server closed.", 0)
} }

View file

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1) "Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1)
returnSuccessResponse(w, r, params["clientid"]+" deleted.") returnSuccessResponse(w, r, params["clientid"]+" deleted.")
} }
// StringWithCharset - returns a random string in a charset
func StringWithCharset(length int, charset string) string {
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var seededRand *rand.Rand = rand.New(
rand.NewSource(time.Now().UnixNano()))

View file

@ -3,14 +3,14 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"net/http" "net/http"
"strings" "strings"
"github.com/go-playground/validator/v10"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -26,6 +26,13 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE") r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE")
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET") r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET")
r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET") r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET")
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET")
}
func throwOauthError(response http.ResponseWriter, request *http.Request) {
returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized"))
} }
// Node authenticates using its password and retrieves a JWT for authorization. // Node authenticates using its password and retrieves a JWT for authorization.
@ -181,37 +188,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error {
return nil return nil
} }
// HasAdmin - checks if server has an admin
func HasAdmin() (bool, error) {
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
} else {
return true, err
}
}
for _, value := range collection { // filter for isadmin true
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
if user.IsAdmin {
return true, nil
}
}
return false, err
}
func hasAdmin(w http.ResponseWriter, r *http.Request) { func hasAdmin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
hasadmin, err := HasAdmin() hasadmin, err := logic.HasAdmin()
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
@ -221,20 +202,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
} }
// GetUser - gets a user
func GetUser(username string) (models.ReturnUser, error) {
var user models.ReturnUser
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
return user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return models.ReturnUser{}, err
}
return user, err
}
// GetUserInternal - gets an internal user // GetUserInternal - gets an internal user
func GetUserInternal(username string) (models.User, error) { func GetUserInternal(username string) (models.User, error) {
@ -249,30 +216,6 @@ func GetUserInternal(username string) (models.User, error) {
return user, err return user, err
} }
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
var users []models.ReturnUser
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
for _, value := range collection {
var user models.ReturnUser
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
users = append(users, user)
}
return users, err
}
// Get an individual node. Nothin fancy here folks. // Get an individual node. Nothin fancy here folks.
func getUser(w http.ResponseWriter, r *http.Request) { func getUser(w http.ResponseWriter, r *http.Request) {
// set header. // set header.
@ -280,7 +223,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r) var params = mux.Vars(r)
usernameFetched := params["username"] usernameFetched := params["username"]
user, err := GetUser(usernameFetched) user, err := logic.GetUser(usernameFetched)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
@ -295,7 +238,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
// set header. // set header.
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
users, err := GetUsers() users, err := logic.GetUsers()
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
@ -306,42 +249,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(users) json.NewEncoder(w).Encode(users)
} }
// CreateUser - creates a user
func CreateUser(user models.User) (models.User, error) {
// check if user exists
if _, err := GetUser(user.UserName); err == nil {
return models.User{}, errors.New("user exists")
}
err := ValidateUser("create", user)
if err != nil {
return models.User{}, err
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
if err != nil {
return user, err
}
// set password to encrypted password
user.Password = string(hash)
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
if tokenString == "" {
// returnErrorResponse(w, r, errorResponse)
return user, err
}
// connect db
data, err := json.Marshal(&user)
if err != nil {
return user, err
}
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
return user, err
}
func createAdmin(w http.ResponseWriter, r *http.Request) { func createAdmin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -349,7 +256,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
// get node from body of request // get node from body of request
_ = json.NewDecoder(r.Body).Decode(&admin) _ = json.NewDecoder(r.Body).Decode(&admin)
admin, err := CreateAdmin(admin) admin, err := logic.CreateAdmin(admin)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) returnErrorResponse(w, r, formatError(err, "badrequest"))
@ -359,18 +266,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(admin) json.NewEncoder(w).Encode(admin)
} }
func CreateAdmin(admin models.User) (models.User, error) {
hasadmin, err := HasAdmin()
if err != nil {
return models.User{}, err
}
if hasadmin {
return models.User{}, errors.New("admin user already exists")
}
admin.IsAdmin = true
return CreateUser(admin)
}
func createUser(w http.ResponseWriter, r *http.Request) { func createUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -378,7 +273,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
// get node from body of request // get node from body of request
_ = json.NewDecoder(r.Body).Decode(&user) _ = json.NewDecoder(r.Body).Decode(&user)
user, err := CreateUser(user) user, err := logic.CreateUser(user)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) returnErrorResponse(w, r, formatError(err, "badrequest"))
@ -388,52 +283,6 @@ func createUser(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(user) json.NewEncoder(w).Encode(user)
} }
// UpdateUser - updates a given user
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
//check if user exists
if _, err := GetUser(user.UserName); err != nil {
return models.User{}, err
}
err := ValidateUser("update", userchange)
if err != nil {
return models.User{}, err
}
queryUser := user.UserName
if userchange.UserName != "" {
user.UserName = userchange.UserName
}
if len(userchange.Networks) > 0 {
user.Networks = userchange.Networks
}
if userchange.Password != "" {
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
if err != nil {
return userchange, err
}
// set password to encrypted password
userchange.Password = string(hash)
user.Password = userchange.Password
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return models.User{}, err
}
data, err := json.Marshal(&user)
if err != nil {
return models.User{}, err
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return models.User{}, err
}
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
return user, nil
}
func updateUser(w http.ResponseWriter, r *http.Request) { func updateUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r) var params = mux.Vars(r)
@ -453,7 +302,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userchange.Networks = nil userchange.Networks = nil
user, err = UpdateUser(userchange, user) user, err = logic.UpdateUser(userchange, user)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) returnErrorResponse(w, r, formatError(err, "badrequest"))
return return
@ -480,7 +329,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
return return
} }
user, err = UpdateUser(userchange, user) user, err = logic.UpdateUser(userchange, user)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "badrequest")) returnErrorResponse(w, r, formatError(err, "badrequest"))
return return
@ -489,20 +338,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(user) json.NewEncoder(w).Encode(user)
} }
// DeleteUser - deletes a given user
func DeleteUser(user string) (bool, error) {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return false, errors.New("user does not exist")
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
if err != nil {
return false, err
}
return true, nil
}
func deleteUser(w http.ResponseWriter, r *http.Request) { func deleteUser(w http.ResponseWriter, r *http.Request) {
// Set header // Set header
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -511,7 +346,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r) var params = mux.Vars(r)
username := params["username"] username := params["username"]
success, err := DeleteUser(username) success, err := logic.DeleteUser(username)
if err != nil { if err != nil {
returnErrorResponse(w, r, formatError(err, "internal")) returnErrorResponse(w, r, formatError(err, "internal"))
@ -524,17 +359,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
functions.PrintUserLog(username, "was deleted", 1) functions.PrintUserLog(username, "was deleted", 1)
json.NewEncoder(w).Encode(params["username"] + " deleted.") json.NewEncoder(w).Encode(params["username"] + " deleted.")
} }
// ValidateUser - validates a user model
func ValidateUser(operation string, user models.User) error {
v := validator.New()
err := v.Struct(user)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
fmt.Println(e)
}
}
return err
}

View file

@ -87,11 +87,11 @@ func getCurrentDB() map[string]interface{} {
} }
func InitializeDatabase() error { func InitializeDatabase() error {
log.Println("connecting to", servercfg.GetDB()) log.Println("[netmaker] connecting to", servercfg.GetDB())
tperiod := time.Now().Add(10 * time.Second) tperiod := time.Now().Add(10 * time.Second)
for { for {
if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil { if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil {
log.Println("unable to connect to db, retrying . . .") log.Println("[netmaker] unable to connect to db, retrying . . .")
if time.Now().After(tperiod) { if time.Now().After(tperiod) {
return err return err
} }

199
logic/auth.go Normal file
View file

@ -0,0 +1,199 @@
package logic
import (
"encoding/json"
"errors"
"github.com/go-playground/validator/v10"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/functions"
"github.com/gravitl/netmaker/models"
"golang.org/x/crypto/bcrypt"
)
// HasAdmin - checks if server has an admin
func HasAdmin() (bool, error) {
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
} else {
return true, err
}
}
for _, value := range collection { // filter for isadmin true
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
if user.IsAdmin {
return true, nil
}
}
return false, err
}
// GetUser - gets a user
func GetUser(username string) (models.ReturnUser, error) {
var user models.ReturnUser
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
return user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return models.ReturnUser{}, err
}
return user, err
}
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
var users []models.ReturnUser
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
for _, value := range collection {
var user models.ReturnUser
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
users = append(users, user)
}
return users, err
}
// CreateUser - creates a user
func CreateUser(user models.User) (models.User, error) {
// check if user exists
if _, err := GetUser(user.UserName); err == nil {
return models.User{}, errors.New("user exists")
}
var err = ValidateUser(user)
if err != nil {
return models.User{}, err
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
if err != nil {
return user, err
}
// set password to encrypted password
user.Password = string(hash)
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
if tokenString == "" {
// returnErrorResponse(w, r, errorResponse)
return user, err
}
// connect db
data, err := json.Marshal(&user)
if err != nil {
return user, err
}
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
return user, err
}
// CreateAdmin - creates an admin user
func CreateAdmin(admin models.User) (models.User, error) {
hasadmin, err := HasAdmin()
if err != nil {
return models.User{}, err
}
if hasadmin {
return models.User{}, errors.New("admin user already exists")
}
admin.IsAdmin = true
return CreateUser(admin)
}
// UpdateUser - updates a given user
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
//check if user exists
if _, err := GetUser(user.UserName); err != nil {
return models.User{}, err
}
err := ValidateUser(userchange)
if err != nil {
return models.User{}, err
}
queryUser := user.UserName
if userchange.UserName != "" {
user.UserName = userchange.UserName
}
if len(userchange.Networks) > 0 {
user.Networks = userchange.Networks
}
if userchange.Password != "" {
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
if err != nil {
return userchange, err
}
// set password to encrypted password
userchange.Password = string(hash)
user.Password = userchange.Password
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return models.User{}, err
}
data, err := json.Marshal(&user)
if err != nil {
return models.User{}, err
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return models.User{}, err
}
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
return user, nil
}
// ValidateUser - validates a user model
func ValidateUser(user models.User) error {
v := validator.New()
err := v.Struct(user)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
Log(e.Error(), 2)
}
}
return err
}
// DeleteUser - deletes a given user
func DeleteUser(user string) (bool, error) {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return false, errors.New("user does not exist")
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
if err != nil {
return false, err
}
return true, nil
}

View file

@ -5,6 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"log" "log"
"math/rand"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin
return peers, err return peers, err
} }
// RandomString - returns a random string in a charset
func RandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}
func setPeerInfo(node models.Node) models.Node { func setPeerInfo(node models.Node) models.Node {
var peer models.Node var peer models.Node
peer.RelayAddrs = node.RelayAddrs peer.RelayAddrs = node.RelayAddrs
@ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node {
func Log(message string, loglevel int) { func Log(message string, loglevel int) {
log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile)) log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile))
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 { if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 {
log.Println("[netmaker] " + message) log.Println("[netmaker] " + message)
} }
} }

37
main.go
View file

@ -13,11 +13,13 @@ import (
"sync" "sync"
"time" "time"
"github.com/gravitl/netmaker/auth"
controller "github.com/gravitl/netmaker/controllers" controller "github.com/gravitl/netmaker/controllers"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/dnslogic" "github.com/gravitl/netmaker/dnslogic"
"github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/functions"
nodepb "github.com/gravitl/netmaker/grpc" nodepb "github.com/gravitl/netmaker/grpc"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/servercfg"
@ -35,20 +37,27 @@ func main() {
func initialize() { // Client Mode Prereq Check func initialize() { // Client Mode Prereq Check
var err error var err error
if err = database.InitializeDatabase(); err != nil { if err = database.InitializeDatabase(); err != nil {
log.Println("Error connecting to database.") logic.Log("Error connecting to database", 0)
log.Fatal(err) log.Fatal(err)
} }
log.Println("database successfully connected.") logic.Log("database successfully connected", 0)
var authProvider = auth.InitializeAuthProvider()
if authProvider != "" {
logic.Log("OAuth provider, "+authProvider+", initialized", 0)
}
if servercfg.IsClientMode() != "off" { if servercfg.IsClientMode() != "off" {
output, err := ncutils.RunCmd("id -u", true) output, err := ncutils.RunCmd("id -u", true)
if err != nil { if err != nil {
log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.") logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0)
log.Fatal(output, err) log.Fatal(output, err)
} }
uid, err := strconv.Atoi(string(output[:len(output)-1])) uid, err := strconv.Atoi(string(output[:len(output)-1]))
if err != nil { if err != nil {
log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.") logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0)
log.Fatal(err) log.Fatal(err)
} }
if uid != 0 { if uid != 0 {
@ -74,7 +83,7 @@ func startControllers() {
if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" { if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" {
err := servercfg.SetHost() err := servercfg.SetHost()
if err != nil { if err != nil {
log.Println("Unable to Set host. Exiting...") logic.Log("Unable to Set host. Exiting...", 0)
log.Fatal(err) log.Fatal(err)
} }
} }
@ -90,7 +99,7 @@ func startControllers() {
if servercfg.IsDNSMode() { if servercfg.IsDNSMode() {
err := dnslogic.SetDNS() err := dnslogic.SetDNS()
if err != nil { if err != nil {
log.Println("error occurred initializing DNS:", err) logic.Log("error occurred initializing DNS: "+err.Error(), 0)
} }
} }
//Run Rest Server //Run Rest Server
@ -98,7 +107,7 @@ func startControllers() {
if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" { if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" {
err := servercfg.SetHost() err := servercfg.SetHost()
if err != nil { if err != nil {
log.Println("Unable to Set host. Exiting...") logic.Log("Unable to Set host. Exiting...", 0)
log.Fatal(err) log.Fatal(err)
} }
} }
@ -106,11 +115,11 @@ func startControllers() {
controller.HandleRESTRequests(&waitnetwork) controller.HandleRESTRequests(&waitnetwork)
} }
if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() { if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() {
log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.") logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0)
} }
waitnetwork.Wait() waitnetwork.Wait()
log.Println("[netmaker] exiting") logic.Log("exiting", 0)
} }
func runClient(wg *sync.WaitGroup) { func runClient(wg *sync.WaitGroup) {
@ -139,7 +148,7 @@ func runGRPC(wg *sync.WaitGroup) {
listener, err := net.Listen("tcp", ":"+grpcport) listener, err := net.Listen("tcp", ":"+grpcport)
// Handle errors if any // Handle errors if any
if err != nil { if err != nil {
log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err) log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err)
} }
s := grpc.NewServer( s := grpc.NewServer(
@ -157,7 +166,7 @@ func runGRPC(wg *sync.WaitGroup) {
log.Fatalf("Failed to serve: %v", err) log.Fatalf("Failed to serve: %v", err)
} }
}() }()
log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)") logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0)
// Right way to stop the server using a SHUTDOWN HOOK // Right way to stop the server using a SHUTDOWN HOOK
// Create a channel to receive OS signals // Create a channel to receive OS signals
@ -172,11 +181,11 @@ func runGRPC(wg *sync.WaitGroup) {
<-c <-c
// After receiving CTRL+C Properly stop the server // After receiving CTRL+C Properly stop the server
log.Println("Stopping the Agent server...") logic.Log("Stopping the Agent server...", 0)
s.Stop() s.Stop()
listener.Close() listener.Close()
log.Println("Agent server closed..") logic.Log("Agent server closed..", 0)
log.Println("Closed DB connection.") logic.Log("Closed DB connection.", 0)
} }
func authServerUnaryInterceptor() grpc.ServerOption { func authServerUnaryInterceptor() grpc.ServerOption {

View file

@ -72,8 +72,20 @@ func GetServerConfig() config.ServerConfig {
cfg.AuthProvider = authInfo[0] cfg.AuthProvider = authInfo[0]
cfg.ClientID = authInfo[1] cfg.ClientID = authInfo[1]
cfg.ClientSecret = authInfo[2] cfg.ClientSecret = authInfo[2]
cfg.FrontendURL = GetFrontendURL()
return cfg return cfg
} }
func GetFrontendURL() string {
var frontend = ""
if os.Getenv("FRONTEND_URL") != "" {
frontend = os.Getenv("FRONTEND_URL")
} else if config.Config.Server.FrontendURL != "" {
frontend = config.Config.Server.FrontendURL
}
return frontend
}
func GetAPIConnString() string { func GetAPIConnString() string {
conn := "" conn := ""
if os.Getenv("SERVER_API_CONN_STRING") != "" { if os.Getenv("SERVER_API_CONN_STRING") != "" {
@ -84,7 +96,7 @@ func GetAPIConnString() string {
return conn return conn
} }
func GetVersion() string { func GetVersion() string {
version := "0.8.4" version := "0.8.5"
if config.Config.Server.Version != "" { if config.Config.Server.Version != "" {
version = config.Config.Server.Version version = config.Config.Server.Version
} }