mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-24 22:15:36 +08:00
began oauth implementation
This commit is contained in:
parent
8a54f50676
commit
4e4e8b3ab5
11 changed files with 330 additions and 243 deletions
31
auth/auth.go
31
auth/auth.go
|
@ -1,6 +1,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
|
@ -13,14 +14,20 @@ const (
|
|||
get_user_info = "getuserinfo"
|
||||
handle_callback = "handlecallback"
|
||||
handle_login = "handlelogin"
|
||||
oauth_state_string = "netmaker-oauth-state"
|
||||
google_provider_name = "google"
|
||||
azure_ad_provider_name = "azure-ad"
|
||||
github_provider_name = "github"
|
||||
verify_user = "verifyuser"
|
||||
)
|
||||
|
||||
var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
|
||||
var auth_provider *oauth2.Config
|
||||
|
||||
type OauthUser struct {
|
||||
Email string `json:"email" bson:"email"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
func getCurrentAuthFunctions() map[string]interface{} {
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var authProvider = authInfo[0]
|
||||
|
@ -37,14 +44,14 @@ func getCurrentAuthFunctions() map[string]interface{} {
|
|||
}
|
||||
|
||||
// InitializeAuthProvider - initializes the auth provider if any is present
|
||||
func InitializeAuthProvider() bool {
|
||||
func InitializeAuthProvider() string {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return false
|
||||
return ""
|
||||
}
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString(), authInfo[1], authInfo[2])
|
||||
return auth_provider != nil
|
||||
functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
// HandleAuthCallback - handles oauth callback
|
||||
|
@ -64,3 +71,17 @@ func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
|
||||
// VerifyUserToken - checks if oauth2 token is valid
|
||||
func VerifyUserToken(accessToken string) bool {
|
||||
var token = &oauth2.Token{}
|
||||
var err = json.Unmarshal([]byte(accessToken), token)
|
||||
if err != nil || !token.Valid() {
|
||||
return false
|
||||
}
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return false
|
||||
}
|
||||
return functions[verify_user].(func(*oauth2.Token) bool)(token)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
@ -14,6 +17,7 @@ var google_functions = map[string]interface{}{
|
|||
get_user_info: getUserInfo,
|
||||
handle_callback: handleGoogleCallback,
|
||||
handle_login: handleGoogleLogin,
|
||||
verify_user: verifyGoogleUser,
|
||||
}
|
||||
|
||||
// == handle google authentication here ==
|
||||
|
@ -29,6 +33,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) {
|
|||
}
|
||||
|
||||
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
oauth_state_string = logic.RandomString(16)
|
||||
url := auth_provider.AuthCodeURL(oauth_state_string)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
@ -38,20 +43,26 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code"))
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
http.Redirect(w, r, "/api/oauth/error", http.StatusTemporaryRedirect)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "Content: %s\n", content)
|
||||
logic.Log("completed google oauth sigin in for "+content.Email, 0)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth="+content.AccessToken+"&email="+content.Email, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func getUserInfo(state string, code string) ([]byte, error) {
|
||||
func getUserInfo(state string, code string) (*OauthUser, error) {
|
||||
if state != oauth_state_string {
|
||||
return nil, fmt.Errorf("invalid oauth state")
|
||||
}
|
||||
token, err := auth_provider.Exchange(oauth2.NoContext, code)
|
||||
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
|
||||
}
|
||||
var data []byte
|
||||
data, err = json.Marshal(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
|
||||
}
|
||||
response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
|
||||
|
@ -61,5 +72,19 @@ func getUserInfo(state string, code string) ([]byte, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
|
||||
}
|
||||
return contents, nil
|
||||
var userInfo = &OauthUser{}
|
||||
if err = json.Unmarshal(contents, userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
|
||||
}
|
||||
userInfo.AccessToken = string(data)
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func verifyGoogleUser(token *oauth2.Token) bool {
|
||||
if token.Valid() {
|
||||
var err error
|
||||
_, err = http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
|
||||
return err == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -65,6 +65,7 @@ type ServerConfig struct {
|
|||
AuthProvider string `yaml:"authprovider"`
|
||||
ClientID string `yaml:"clientid"`
|
||||
ClientSecret string `yaml:"clientsecret"`
|
||||
FrontendURL string `yaml:"frontendurl"`
|
||||
}
|
||||
|
||||
// Generic SQL Config
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
|
@ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
|
|||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Println("REST Server successfully started on port " + port + " (REST)")
|
||||
logic.Log("REST Server successfully started on port "+port+" (REST)", 0)
|
||||
c := make(chan os.Signal)
|
||||
|
||||
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
|
||||
|
@ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
|
|||
<-c
|
||||
|
||||
// After receiving CTRL+C Properly stop the server
|
||||
log.Println("Stopping the REST server...")
|
||||
logic.Log("Stopping the REST server...", 0)
|
||||
srv.Shutdown(context.TODO())
|
||||
log.Println("REST Server closed.")
|
||||
logic.Log("REST Server closed.", 0)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1)
|
||||
returnSuccessResponse(w, r, params["clientid"]+" deleted.")
|
||||
}
|
||||
|
||||
// StringWithCharset - returns a random string in a charset
|
||||
func StringWithCharset(length int, charset string) string {
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
|
|
@ -3,14 +3,14 @@ package controller
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -26,6 +26,13 @@ func userHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE")
|
||||
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET")
|
||||
r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET")
|
||||
}
|
||||
|
||||
func throwOauthError(response http.ResponseWriter, request *http.Request) {
|
||||
returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized"))
|
||||
}
|
||||
|
||||
// Node authenticates using its password and retrieves a JWT for authorization.
|
||||
|
@ -181,37 +188,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// HasAdmin - checks if server has an admin
|
||||
func HasAdmin() (bool, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return false, nil
|
||||
} else {
|
||||
return true, err
|
||||
|
||||
}
|
||||
}
|
||||
for _, value := range collection { // filter for isadmin true
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if user.IsAdmin {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
func hasAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
hasadmin, err := HasAdmin()
|
||||
hasadmin, err := logic.HasAdmin()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -221,20 +202,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
}
|
||||
|
||||
// GetUser - gets a user
|
||||
func GetUser(username string) (models.ReturnUser, error) {
|
||||
|
||||
var user models.ReturnUser
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &user); err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
// GetUserInternal - gets an internal user
|
||||
func GetUserInternal(username string) (models.User, error) {
|
||||
|
||||
|
@ -249,30 +216,6 @@ func GetUserInternal(username string) (models.User, error) {
|
|||
return user, err
|
||||
}
|
||||
|
||||
// GetUsers - gets users
|
||||
func GetUsers() ([]models.ReturnUser, error) {
|
||||
|
||||
var users []models.ReturnUser
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
|
||||
var user models.ReturnUser
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue // get users
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// Get an individual node. Nothin fancy here folks.
|
||||
func getUser(w http.ResponseWriter, r *http.Request) {
|
||||
// set header.
|
||||
|
@ -280,7 +223,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var params = mux.Vars(r)
|
||||
usernameFetched := params["username"]
|
||||
user, err := GetUser(usernameFetched)
|
||||
user, err := logic.GetUser(usernameFetched)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -295,7 +238,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
|
|||
// set header.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
users, err := GetUsers()
|
||||
users, err := logic.GetUsers()
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -306,42 +249,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(users)
|
||||
}
|
||||
|
||||
// CreateUser - creates a user
|
||||
func CreateUser(user models.User) (models.User, error) {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err == nil {
|
||||
return models.User{}, errors.New("user exists")
|
||||
}
|
||||
err := ValidateUser("create", user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
user.Password = string(hash)
|
||||
|
||||
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
|
||||
|
||||
if tokenString == "" {
|
||||
// returnErrorResponse(w, r, errorResponse)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// connect db
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
func createAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
|
@ -349,7 +256,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
// get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&admin)
|
||||
|
||||
admin, err := CreateAdmin(admin)
|
||||
admin, err := logic.CreateAdmin(admin)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
|
@ -359,18 +266,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(admin)
|
||||
}
|
||||
|
||||
func CreateAdmin(admin models.User) (models.User, error) {
|
||||
hasadmin, err := HasAdmin()
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if hasadmin {
|
||||
return models.User{}, errors.New("admin user already exists")
|
||||
}
|
||||
admin.IsAdmin = true
|
||||
return CreateUser(admin)
|
||||
}
|
||||
|
||||
func createUser(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
|
@ -378,7 +273,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
|
|||
// get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&user)
|
||||
|
||||
user, err := CreateUser(user)
|
||||
user, err := logic.CreateUser(user)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
|
@ -388,52 +283,6 @@ func createUser(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// UpdateUser - updates a given user
|
||||
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
|
||||
//check if user exists
|
||||
if _, err := GetUser(user.UserName); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
err := ValidateUser("update", userchange)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
queryUser := user.UserName
|
||||
|
||||
if userchange.UserName != "" {
|
||||
user.UserName = userchange.UserName
|
||||
}
|
||||
if len(userchange.Networks) > 0 {
|
||||
user.Networks = userchange.Networks
|
||||
}
|
||||
if userchange.Password != "" {
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
|
||||
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
userchange.Password = string(hash)
|
||||
|
||||
user.Password = userchange.Password
|
||||
}
|
||||
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
|
@ -453,7 +302,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
userchange.Networks = nil
|
||||
user, err = UpdateUser(userchange, user)
|
||||
user, err = logic.UpdateUser(userchange, user)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -480,7 +329,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
user, err = UpdateUser(userchange, user)
|
||||
user, err = logic.UpdateUser(userchange, user)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -489,20 +338,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) (bool, error) {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return false, errors.New("user does not exist")
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
@ -511,7 +346,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
var params = mux.Vars(r)
|
||||
|
||||
username := params["username"]
|
||||
success, err := DeleteUser(username)
|
||||
success, err := logic.DeleteUser(username)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -524,17 +359,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
functions.PrintUserLog(username, "was deleted", 1)
|
||||
json.NewEncoder(w).Encode(params["username"] + " deleted.")
|
||||
}
|
||||
|
||||
// ValidateUser - validates a user model
|
||||
func ValidateUser(operation string, user models.User) error {
|
||||
|
||||
v := validator.New()
|
||||
err := v.Struct(user)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
fmt.Println(e)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -87,11 +87,11 @@ func getCurrentDB() map[string]interface{} {
|
|||
}
|
||||
|
||||
func InitializeDatabase() error {
|
||||
log.Println("connecting to", servercfg.GetDB())
|
||||
log.Println("[netmaker] connecting to", servercfg.GetDB())
|
||||
tperiod := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil {
|
||||
log.Println("unable to connect to db, retrying . . .")
|
||||
log.Println("[netmaker] unable to connect to db, retrying . . .")
|
||||
if time.Now().After(tperiod) {
|
||||
return err
|
||||
}
|
||||
|
|
199
logic/auth.go
Normal file
199
logic/auth.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HasAdmin - checks if server has an admin
|
||||
func HasAdmin() (bool, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return false, nil
|
||||
} else {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
for _, value := range collection { // filter for isadmin true
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if user.IsAdmin {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// GetUser - gets a user
|
||||
func GetUser(username string) (models.ReturnUser, error) {
|
||||
|
||||
var user models.ReturnUser
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &user); err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
// GetUsers - gets users
|
||||
func GetUsers() ([]models.ReturnUser, error) {
|
||||
|
||||
var users []models.ReturnUser
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
|
||||
var user models.ReturnUser
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue // get users
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// CreateUser - creates a user
|
||||
func CreateUser(user models.User) (models.User, error) {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err == nil {
|
||||
return models.User{}, errors.New("user exists")
|
||||
}
|
||||
var err = ValidateUser(user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
user.Password = string(hash)
|
||||
|
||||
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
|
||||
|
||||
if tokenString == "" {
|
||||
// returnErrorResponse(w, r, errorResponse)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// connect db
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
// CreateAdmin - creates an admin user
|
||||
func CreateAdmin(admin models.User) (models.User, error) {
|
||||
hasadmin, err := HasAdmin()
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if hasadmin {
|
||||
return models.User{}, errors.New("admin user already exists")
|
||||
}
|
||||
admin.IsAdmin = true
|
||||
return CreateUser(admin)
|
||||
}
|
||||
|
||||
// UpdateUser - updates a given user
|
||||
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
|
||||
//check if user exists
|
||||
if _, err := GetUser(user.UserName); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
err := ValidateUser(userchange)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
queryUser := user.UserName
|
||||
|
||||
if userchange.UserName != "" {
|
||||
user.UserName = userchange.UserName
|
||||
}
|
||||
if len(userchange.Networks) > 0 {
|
||||
user.Networks = userchange.Networks
|
||||
}
|
||||
if userchange.Password != "" {
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
|
||||
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
userchange.Password = string(hash)
|
||||
|
||||
user.Password = userchange.Password
|
||||
}
|
||||
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ValidateUser - validates a user model
|
||||
func ValidateUser(user models.User) error {
|
||||
|
||||
v := validator.New()
|
||||
err := v.Struct(user)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
Log(e.Error(), 2)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) (bool, error) {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return false, errors.New("user does not exist")
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin
|
|||
return peers, err
|
||||
}
|
||||
|
||||
// RandomString - returns a random string in a charset
|
||||
func RandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func setPeerInfo(node models.Node) models.Node {
|
||||
var peer models.Node
|
||||
peer.RelayAddrs = node.RelayAddrs
|
||||
|
@ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node {
|
|||
|
||||
func Log(message string, loglevel int) {
|
||||
log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile))
|
||||
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 {
|
||||
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 {
|
||||
log.Println("[netmaker] " + message)
|
||||
}
|
||||
}
|
||||
|
|
37
main.go
37
main.go
|
@ -13,11 +13,13 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/dnslogic"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
nodepb "github.com/gravitl/netmaker/grpc"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
|
@ -35,20 +37,27 @@ func main() {
|
|||
|
||||
func initialize() { // Client Mode Prereq Check
|
||||
var err error
|
||||
|
||||
if err = database.InitializeDatabase(); err != nil {
|
||||
log.Println("Error connecting to database.")
|
||||
logic.Log("Error connecting to database", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Println("database successfully connected.")
|
||||
logic.Log("database successfully connected", 0)
|
||||
|
||||
var authProvider = auth.InitializeAuthProvider()
|
||||
if authProvider != "" {
|
||||
logic.Log("OAuth provider, "+authProvider+", initialized", 0)
|
||||
}
|
||||
|
||||
if servercfg.IsClientMode() != "off" {
|
||||
output, err := ncutils.RunCmd("id -u", true)
|
||||
if err != nil {
|
||||
log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.")
|
||||
logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0)
|
||||
log.Fatal(output, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(string(output[:len(output)-1]))
|
||||
if err != nil {
|
||||
log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.")
|
||||
logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
if uid != 0 {
|
||||
|
@ -74,7 +83,7 @@ func startControllers() {
|
|||
if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" {
|
||||
err := servercfg.SetHost()
|
||||
if err != nil {
|
||||
log.Println("Unable to Set host. Exiting...")
|
||||
logic.Log("Unable to Set host. Exiting...", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +99,7 @@ func startControllers() {
|
|||
if servercfg.IsDNSMode() {
|
||||
err := dnslogic.SetDNS()
|
||||
if err != nil {
|
||||
log.Println("error occurred initializing DNS:", err)
|
||||
logic.Log("error occurred initializing DNS: "+err.Error(), 0)
|
||||
}
|
||||
}
|
||||
//Run Rest Server
|
||||
|
@ -98,7 +107,7 @@ func startControllers() {
|
|||
if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" {
|
||||
err := servercfg.SetHost()
|
||||
if err != nil {
|
||||
log.Println("Unable to Set host. Exiting...")
|
||||
logic.Log("Unable to Set host. Exiting...", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -106,11 +115,11 @@ func startControllers() {
|
|||
controller.HandleRESTRequests(&waitnetwork)
|
||||
}
|
||||
if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() {
|
||||
log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.")
|
||||
logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0)
|
||||
}
|
||||
|
||||
waitnetwork.Wait()
|
||||
log.Println("[netmaker] exiting")
|
||||
logic.Log("exiting", 0)
|
||||
}
|
||||
|
||||
func runClient(wg *sync.WaitGroup) {
|
||||
|
@ -139,7 +148,7 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
listener, err := net.Listen("tcp", ":"+grpcport)
|
||||
// Handle errors if any
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err)
|
||||
log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err)
|
||||
}
|
||||
|
||||
s := grpc.NewServer(
|
||||
|
@ -157,7 +166,7 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
log.Fatalf("Failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)")
|
||||
logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0)
|
||||
|
||||
// Right way to stop the server using a SHUTDOWN HOOK
|
||||
// Create a channel to receive OS signals
|
||||
|
@ -172,11 +181,11 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
<-c
|
||||
|
||||
// After receiving CTRL+C Properly stop the server
|
||||
log.Println("Stopping the Agent server...")
|
||||
logic.Log("Stopping the Agent server...", 0)
|
||||
s.Stop()
|
||||
listener.Close()
|
||||
log.Println("Agent server closed..")
|
||||
log.Println("Closed DB connection.")
|
||||
logic.Log("Agent server closed..", 0)
|
||||
logic.Log("Closed DB connection.", 0)
|
||||
}
|
||||
|
||||
func authServerUnaryInterceptor() grpc.ServerOption {
|
||||
|
|
|
@ -72,8 +72,20 @@ func GetServerConfig() config.ServerConfig {
|
|||
cfg.AuthProvider = authInfo[0]
|
||||
cfg.ClientID = authInfo[1]
|
||||
cfg.ClientSecret = authInfo[2]
|
||||
cfg.FrontendURL = GetFrontendURL()
|
||||
|
||||
return cfg
|
||||
}
|
||||
func GetFrontendURL() string {
|
||||
var frontend = ""
|
||||
if os.Getenv("FRONTEND_URL") != "" {
|
||||
frontend = os.Getenv("FRONTEND_URL")
|
||||
} else if config.Config.Server.FrontendURL != "" {
|
||||
frontend = config.Config.Server.FrontendURL
|
||||
}
|
||||
return frontend
|
||||
}
|
||||
|
||||
func GetAPIConnString() string {
|
||||
conn := ""
|
||||
if os.Getenv("SERVER_API_CONN_STRING") != "" {
|
||||
|
@ -84,7 +96,7 @@ func GetAPIConnString() string {
|
|||
return conn
|
||||
}
|
||||
func GetVersion() string {
|
||||
version := "0.8.4"
|
||||
version := "0.8.5"
|
||||
if config.Config.Server.Version != "" {
|
||||
version = config.Config.Server.Version
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue