mirror of
https://github.com/gravitl/netmaker.git
synced 2024-09-20 23:36:18 +08:00
Merge pull request #386 from gravitl/feature_v0.8.5_oauth
Feature v0.8.5 oauth
This commit is contained in:
commit
a22c029a51
|
@ -8,7 +8,7 @@
|
|||
|
||||
<p align="center">
|
||||
<a href="https://github.com/gravitl/netmaker/releases">
|
||||
<img src="https://img.shields.io/badge/Version-0.8.4-informational?style=flat-square" />
|
||||
<img src="https://img.shields.io/badge/Version-0.8.5-informational?style=flat-square" />
|
||||
</a>
|
||||
<a href="https://discord.gg/zRb9Vfhk8A">
|
||||
<img src="https://img.shields.io/badge/community-discord-informational" />
|
||||
|
|
157
auth/auth.go
Normal file
157
auth/auth.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// == consts ==
|
||||
const (
|
||||
init_provider = "initprovider"
|
||||
get_user_info = "getuserinfo"
|
||||
handle_callback = "handlecallback"
|
||||
handle_login = "handlelogin"
|
||||
google_provider_name = "google"
|
||||
azure_ad_provider_name = "azure-ad"
|
||||
github_provider_name = "github"
|
||||
verify_user = "verifyuser"
|
||||
auth_key = "netmaker_auth"
|
||||
)
|
||||
|
||||
var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
|
||||
var auth_provider *oauth2.Config
|
||||
|
||||
func getCurrentAuthFunctions() map[string]interface{} {
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var authProvider = authInfo[0]
|
||||
switch authProvider {
|
||||
case google_provider_name:
|
||||
return google_functions
|
||||
case azure_ad_provider_name:
|
||||
return azure_ad_functions
|
||||
case github_provider_name:
|
||||
return github_functions
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeAuthProvider - initializes the auth provider if any is present
|
||||
func InitializeAuthProvider() string {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return ""
|
||||
}
|
||||
var _, err = fetchPassValue(logic.RandomString(64))
|
||||
if err != nil {
|
||||
logic.Log(err.Error(), 0)
|
||||
return ""
|
||||
}
|
||||
var currentFrontendURL = servercfg.GetFrontendURL()
|
||||
if currentFrontendURL == "" {
|
||||
return ""
|
||||
}
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString()+"/api/oauth/callback", authInfo[1], authInfo[2])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
// HandleAuthCallback - handles oauth callback
|
||||
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
|
||||
// HandleAuthLogin - handles oauth login
|
||||
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
|
||||
// IsOauthUser - returns
|
||||
func IsOauthUser(user *models.User) error {
|
||||
var currentValue, err = fetchPassValue("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue))
|
||||
return bCryptErr
|
||||
}
|
||||
|
||||
// == private methods ==
|
||||
|
||||
func addUser(email string) error {
|
||||
var hasAdmin, err = logic.HasAdmin()
|
||||
if err != nil {
|
||||
logic.Log("error checking for existence of admin user during OAuth login for "+email+", user not added", 1)
|
||||
return err
|
||||
} // generate random password to adapt to current model
|
||||
var newPass, fetchErr = fetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return fetchErr
|
||||
}
|
||||
var newUser = models.User{
|
||||
UserName: email,
|
||||
Password: newPass,
|
||||
}
|
||||
if !hasAdmin { // must be first attempt, create an admin
|
||||
if newUser, err = logic.CreateAdmin(newUser); err != nil {
|
||||
logic.Log("error creating admin from user, "+email+", user not added", 1)
|
||||
} else {
|
||||
logic.Log("admin created from user, "+email+", was first user added", 0)
|
||||
}
|
||||
} else { // otherwise add to db as admin..?
|
||||
// TODO: add ability to add users with preemptive permissions
|
||||
newUser.IsAdmin = false
|
||||
if newUser, err = logic.CreateUser(newUser); err != nil {
|
||||
logic.Log("error creating user, "+email+", user not added", 1)
|
||||
} else {
|
||||
logic.Log("user created from, "+email+"", 0)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fetchPassValue(newValue string) (string, error) {
|
||||
|
||||
type valueHolder struct {
|
||||
Value string `json:"value" bson:"value"`
|
||||
}
|
||||
var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue))
|
||||
var newValueHolder = &valueHolder{
|
||||
Value: b64NewValue,
|
||||
}
|
||||
var data, marshalErr = json.Marshal(newValueHolder)
|
||||
if marshalErr != nil {
|
||||
return "", marshalErr
|
||||
}
|
||||
|
||||
var currentValue, err = logic.FetchAuthSecret(auth_key, string(data))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder)
|
||||
if unmarshErr != nil {
|
||||
return "", unmarshErr
|
||||
}
|
||||
|
||||
var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
|
||||
if b64Err != nil {
|
||||
logic.Log("could not decode pass", 0)
|
||||
return "", nil
|
||||
}
|
||||
return string(b64CurrentValue), nil
|
||||
}
|
126
auth/azure-ad.go
Normal file
126
auth/azure-ad.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/microsoft"
|
||||
)
|
||||
|
||||
var azure_ad_functions = map[string]interface{}{
|
||||
init_provider: initAzureAD,
|
||||
get_user_info: getAzureUserInfo,
|
||||
handle_callback: handleAzureCallback,
|
||||
handle_login: handleAzureLogin,
|
||||
verify_user: verifyAzureUser,
|
||||
}
|
||||
|
||||
type azureOauthUser struct {
|
||||
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
// == handle azure ad authentication here ==
|
||||
|
||||
func initAzureAD(redirectURL string, clientID string, clientSecret string) {
|
||||
auth_provider = &oauth2.Config{
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: []string{"User.Read"},
|
||||
Endpoint: microsoft.AzureADEndpoint(os.Getenv("AZURE_TENANT")),
|
||||
}
|
||||
}
|
||||
|
||||
func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
|
||||
oauth_state_string = logic.RandomString(16)
|
||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
} else if auth_provider == nil {
|
||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||
return
|
||||
}
|
||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var content, err = getAzureUserInfo(r.FormValue("state"), r.FormValue("code"))
|
||||
if err != nil {
|
||||
logic.Log("error when getting user info from azure: "+err.Error(), 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
_, err = logic.GetUser(content.UserPrincipalName)
|
||||
if err != nil { // user must not exists, so try to make one
|
||||
if err = addUser(content.UserPrincipalName); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var newPass, fetchErr = fetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
||||
// send a netmaker jwt token
|
||||
var authRequest = models.UserAuthParams{
|
||||
UserName: content.UserPrincipalName,
|
||||
Password: newPass,
|
||||
}
|
||||
|
||||
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
|
||||
if jwtErr != nil {
|
||||
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
|
||||
return
|
||||
}
|
||||
|
||||
logic.Log("completed azure OAuth sigin in for "+content.UserPrincipalName, 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.UserPrincipalName, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
|
||||
if state != oauth_state_string {
|
||||
return nil, fmt.Errorf("invalid oauth state")
|
||||
}
|
||||
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
|
||||
}
|
||||
var data []byte
|
||||
data, err = json.Marshal(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
|
||||
}
|
||||
var httpReq, reqErr = http.NewRequest("GET", "https://graph.microsoft.com/v1.0/me", nil)
|
||||
if reqErr != nil {
|
||||
return nil, fmt.Errorf("failed to create request to GitHub")
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
response, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
|
||||
}
|
||||
defer response.Body.Close()
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
|
||||
}
|
||||
var userInfo = &azureOauthUser{}
|
||||
if err = json.Unmarshal(contents, userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
|
||||
}
|
||||
userInfo.AccessToken = string(data)
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func verifyAzureUser(token *oauth2.Token) bool {
|
||||
return token.Valid()
|
||||
}
|
129
auth/github.go
Normal file
129
auth/github.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
)
|
||||
|
||||
var github_functions = map[string]interface{}{
|
||||
init_provider: initGithub,
|
||||
get_user_info: getGithubUserInfo,
|
||||
handle_callback: handleGithubCallback,
|
||||
handle_login: handleGithubLogin,
|
||||
verify_user: verifyGithubUser,
|
||||
}
|
||||
|
||||
type githubOauthUser struct {
|
||||
Login string `json:"login" bson:"login"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
// == handle github authentication here ==
|
||||
|
||||
func initGithub(redirectURL string, clientID string, clientSecret string) {
|
||||
auth_provider = &oauth2.Config{
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: []string{},
|
||||
Endpoint: github.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
|
||||
oauth_state_string = logic.RandomString(16)
|
||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?error=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
} else if auth_provider == nil {
|
||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||
return
|
||||
}
|
||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var content, err = getGithubUserInfo(r.URL.Query().Get("state"), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
logic.Log("error when getting user info from github: "+err.Error(), 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
_, err = logic.GetUser(content.Login)
|
||||
if err != nil { // user must not exist, so try to make one
|
||||
if err = addUser(content.Login); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var newPass, fetchErr = fetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
||||
// send a netmaker jwt token
|
||||
var authRequest = models.UserAuthParams{
|
||||
UserName: content.Login,
|
||||
Password: newPass,
|
||||
}
|
||||
|
||||
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
|
||||
if jwtErr != nil {
|
||||
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
|
||||
return
|
||||
}
|
||||
|
||||
logic.Log("completed github OAuth sigin in for "+content.Login, 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Login, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
|
||||
if state != oauth_state_string {
|
||||
return nil, fmt.Errorf("invalid OAuth state")
|
||||
}
|
||||
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
|
||||
}
|
||||
if !token.Valid() {
|
||||
return nil, fmt.Errorf("GitHub code exchange yielded invalid token")
|
||||
}
|
||||
var data []byte
|
||||
data, err = json.Marshal(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
|
||||
}
|
||||
var httpClient = &http.Client{}
|
||||
var httpReq, reqErr = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if reqErr != nil {
|
||||
return nil, fmt.Errorf("failed to create request to GitHub")
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "token "+token.AccessToken)
|
||||
response, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
|
||||
}
|
||||
defer response.Body.Close()
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
|
||||
}
|
||||
var userInfo = &githubOauthUser{}
|
||||
if err = json.Unmarshal(contents, userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
|
||||
}
|
||||
userInfo.AccessToken = string(data)
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func verifyGithubUser(token *oauth2.Token) bool {
|
||||
return token.Valid()
|
||||
}
|
120
auth/google.go
Normal file
120
auth/google.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var google_functions = map[string]interface{}{
|
||||
init_provider: initGoogle,
|
||||
get_user_info: getGoogleUserInfo,
|
||||
handle_callback: handleGoogleCallback,
|
||||
handle_login: handleGoogleLogin,
|
||||
verify_user: verifyGoogleUser,
|
||||
}
|
||||
|
||||
type googleOauthUser struct {
|
||||
Email string `json:"email" bson:"email"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
// == handle google authentication here ==
|
||||
|
||||
func initGoogle(redirectURL string, clientID string, clientSecret string) {
|
||||
auth_provider = &oauth2.Config{
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
oauth_state_string = logic.RandomString(16)
|
||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
} else if auth_provider == nil {
|
||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||
return
|
||||
}
|
||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var content, err = getGoogleUserInfo(r.FormValue("state"), r.FormValue("code"))
|
||||
if err != nil {
|
||||
logic.Log("error when getting user info from google: "+err.Error(), 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
_, err = logic.GetUser(content.Email)
|
||||
if err != nil { // user must not exists, so try to make one
|
||||
if err = addUser(content.Email); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
var newPass, fetchErr = fetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
||||
// send a netmaker jwt token
|
||||
var authRequest = models.UserAuthParams{
|
||||
UserName: content.Email,
|
||||
Password: newPass,
|
||||
}
|
||||
|
||||
var jwt, jwtErr = logic.VerifyAuthRequest(authRequest)
|
||||
if jwtErr != nil {
|
||||
logic.Log("could not parse jwt for user "+authRequest.UserName, 1)
|
||||
return
|
||||
}
|
||||
|
||||
logic.Log("completed google OAuth sigin in for "+content.Email, 1)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
|
||||
if state != oauth_state_string {
|
||||
return nil, fmt.Errorf("invalid OAuth state")
|
||||
}
|
||||
var token, err = auth_provider.Exchange(oauth2.NoContext, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("code exchange failed: %s", err.Error())
|
||||
}
|
||||
var data []byte
|
||||
data, err = json.Marshal(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert token to json: %s", err.Error())
|
||||
}
|
||||
response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting user info: %s", err.Error())
|
||||
}
|
||||
defer response.Body.Close()
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reading response body: %s", err.Error())
|
||||
}
|
||||
var userInfo = &googleOauthUser{}
|
||||
if err = json.Unmarshal(contents, userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed parsing email from response data: %s", err.Error())
|
||||
}
|
||||
userInfo.AccessToken = string(data)
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func verifyGoogleUser(token *oauth2.Token) bool {
|
||||
return token.Valid()
|
||||
}
|
|
@ -30,49 +30,52 @@ var Config *EnvironmentConfig
|
|||
// EnvironmentConfig :
|
||||
type EnvironmentConfig struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
SQL SQLConfig `yaml:"sql"`
|
||||
SQL SQLConfig `yaml:"sql"`
|
||||
}
|
||||
|
||||
// ServerConfig :
|
||||
type ServerConfig struct {
|
||||
CoreDNSAddr string `yaml:"corednsaddr"`
|
||||
APIConnString string `yaml:"apiconn"`
|
||||
APIHost string `yaml:"apihost"`
|
||||
APIPort string `yaml:"apiport"`
|
||||
GRPCConnString string `yaml:"grpcconn"`
|
||||
GRPCHost string `yaml:"grpchost"`
|
||||
GRPCPort string `yaml:"grpcport"`
|
||||
GRPCSecure string `yaml:"grpcsecure"`
|
||||
MasterKey string `yaml:"masterkey"`
|
||||
AllowedOrigin string `yaml:"allowedorigin"`
|
||||
NodeID string `yaml:"nodeid"`
|
||||
RestBackend string `yaml:"restbackend"`
|
||||
AgentBackend string `yaml:"agentbackend"`
|
||||
ClientMode string `yaml:"clientmode"`
|
||||
DNSMode string `yaml:"dnsmode"`
|
||||
SplitDNS string `yaml:"splitdns"`
|
||||
DisableRemoteIPCheck string `yaml:"disableremoteipcheck"`
|
||||
DisableDefaultNet string `yaml:"disabledefaultnet"`
|
||||
GRPCSSL string `yaml:"grpcssl"`
|
||||
Version string `yaml:"version"`
|
||||
SQLConn string `yaml:"sqlconn"`
|
||||
Platform string `yaml:"platform"`
|
||||
Database string `yaml:database`
|
||||
CheckinInterval string `yaml:checkininterval`
|
||||
DefaultNodeLimit int32 `yaml:"defaultnodelimit"`
|
||||
Verbosity int32 `yaml:"verbosity"`
|
||||
CoreDNSAddr string `yaml:"corednsaddr"`
|
||||
APIConnString string `yaml:"apiconn"`
|
||||
APIHost string `yaml:"apihost"`
|
||||
APIPort string `yaml:"apiport"`
|
||||
GRPCConnString string `yaml:"grpcconn"`
|
||||
GRPCHost string `yaml:"grpchost"`
|
||||
GRPCPort string `yaml:"grpcport"`
|
||||
GRPCSecure string `yaml:"grpcsecure"`
|
||||
MasterKey string `yaml:"masterkey"`
|
||||
AllowedOrigin string `yaml:"allowedorigin"`
|
||||
NodeID string `yaml:"nodeid"`
|
||||
RestBackend string `yaml:"restbackend"`
|
||||
AgentBackend string `yaml:"agentbackend"`
|
||||
ClientMode string `yaml:"clientmode"`
|
||||
DNSMode string `yaml:"dnsmode"`
|
||||
SplitDNS string `yaml:"splitdns"`
|
||||
DisableRemoteIPCheck string `yaml:"disableremoteipcheck"`
|
||||
DisableDefaultNet string `yaml:"disabledefaultnet"`
|
||||
GRPCSSL string `yaml:"grpcssl"`
|
||||
Version string `yaml:"version"`
|
||||
SQLConn string `yaml:"sqlconn"`
|
||||
Platform string `yaml:"platform"`
|
||||
Database string `yaml:database`
|
||||
CheckinInterval string `yaml:checkininterval`
|
||||
DefaultNodeLimit int32 `yaml:"defaultnodelimit"`
|
||||
Verbosity int32 `yaml:"verbosity"`
|
||||
ServerCheckinInterval int64 `yaml:"servercheckininterval"`
|
||||
AuthProvider string `yaml:"authprovider"`
|
||||
ClientID string `yaml:"clientid"`
|
||||
ClientSecret string `yaml:"clientsecret"`
|
||||
FrontendURL string `yaml:"frontendurl"`
|
||||
}
|
||||
|
||||
|
||||
// Generic SQL Config
|
||||
type SQLConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int32 `yaml:"port"`
|
||||
Host string `yaml:"host"`
|
||||
Port int32 `yaml:"port"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
DB string `yaml:"db"`
|
||||
SSLMode string `yaml:"sslmode"`
|
||||
DB string `yaml:"db"`
|
||||
SSLMode string `yaml:"sslmode"`
|
||||
}
|
||||
|
||||
//reading in the env file
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
|
@ -42,8 +43,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
|
|||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Println("REST Server successfully started on port " + port + " (REST)")
|
||||
logic.Log("REST Server successfully started on port "+port+" (REST)", 0)
|
||||
c := make(chan os.Signal)
|
||||
|
||||
// Relay os.Interrupt to our channel (os.Interrupt = CTRL+C)
|
||||
|
@ -55,7 +55,7 @@ func HandleRESTRequests(wg *sync.WaitGroup) {
|
|||
<-c
|
||||
|
||||
// After receiving CTRL+C Properly stop the server
|
||||
log.Println("Stopping the REST server...")
|
||||
logic.Log("Stopping the REST server...", 0)
|
||||
srv.Shutdown(context.TODO())
|
||||
log.Println("REST Server closed.")
|
||||
logic.Log("REST Server closed.", 0)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -413,17 +412,3 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
"Deleted extclient client "+params["clientid"]+" from network "+params["network"], 1)
|
||||
returnSuccessResponse(w, r, params["clientid"]+" deleted.")
|
||||
}
|
||||
|
||||
// StringWithCharset - returns a random string in a charset
|
||||
func StringWithCharset(length int, charset string) string {
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
var seededRand *rand.Rand = rand.New(
|
||||
rand.NewSource(time.Now().UnixNano()))
|
||||
|
|
|
@ -2,9 +2,9 @@ package controller
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
|
@ -48,7 +48,7 @@ func returnErrorResponse(response http.ResponseWriter, request *http.Request, er
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(errorMessage)
|
||||
logic.Log("processed request error: "+errorMessage.Message, 1)
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.WriteHeader(errorMessage.Code)
|
||||
response.Write(jsonResponse)
|
||||
|
|
|
@ -7,12 +7,12 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func userHandlers(r *mux.Router) {
|
||||
|
@ -21,11 +21,19 @@ func userHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/users/adm/createadmin", createAdmin).Methods("POST")
|
||||
r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods("POST")
|
||||
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(updateUser))).Methods("PUT")
|
||||
r.HandleFunc("/api/users/networks/{username}", authorizeUserAdm(http.HandlerFunc(updateUserNetworks))).Methods("PUT")
|
||||
r.HandleFunc("/api/users/{username}/adm", authorizeUserAdm(http.HandlerFunc(updateUserAdm))).Methods("PUT")
|
||||
r.HandleFunc("/api/users/{username}", authorizeUserAdm(http.HandlerFunc(createUser))).Methods("POST")
|
||||
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(deleteUser))).Methods("DELETE")
|
||||
r.HandleFunc("/api/users/{username}", authorizeUser(http.HandlerFunc(getUser))).Methods("GET")
|
||||
r.HandleFunc("/api/users", authorizeUserAdm(http.HandlerFunc(getUsers))).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods("GET")
|
||||
r.HandleFunc("/api/oauth/error", throwOauthError).Methods("GET")
|
||||
}
|
||||
|
||||
func throwOauthError(response http.ResponseWriter, request *http.Request) {
|
||||
returnErrorResponse(response, request, formatError(errors.New("No token returned"), "unauthorized"))
|
||||
}
|
||||
|
||||
// Node authenticates using its password and retrieves a JWT for authorization.
|
||||
|
@ -46,7 +54,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
if err != nil {
|
||||
returnErrorResponse(response, request, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -79,35 +87,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
|
|||
response.Write(successJSONResponse)
|
||||
}
|
||||
|
||||
// VerifyAuthRequest - verifies an auth request
|
||||
func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
|
||||
var result models.User
|
||||
if authRequest.UserName == "" {
|
||||
return "", errors.New("username can't be empty")
|
||||
} else if authRequest.Password == "" {
|
||||
return "", errors.New("password can't be empty")
|
||||
}
|
||||
//Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
|
||||
if err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &result); err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
|
||||
// compare password from request to stored password in database
|
||||
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
|
||||
// TODO: Consider a way of hashing the password client side before sending, or using certificates
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
|
||||
//Create a new JWT for the node
|
||||
tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin)
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// The middleware for most requests to the API
|
||||
// They all pass through here first
|
||||
// This will validate the JWT (or check for master token)
|
||||
|
@ -181,37 +160,11 @@ func ValidateUserToken(token string, user string, adminonly bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// HasAdmin - checks if server has an admin
|
||||
func HasAdmin() (bool, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return false, nil
|
||||
} else {
|
||||
return true, err
|
||||
|
||||
}
|
||||
}
|
||||
for _, value := range collection { // filter for isadmin true
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if user.IsAdmin {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
func hasAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
hasadmin, err := HasAdmin()
|
||||
hasadmin, err := logic.HasAdmin()
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
|
@ -221,20 +174,6 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
}
|
||||
|
||||
// GetUser - gets a user
|
||||
func GetUser(username string) (models.ReturnUser, error) {
|
||||
|
||||
var user models.ReturnUser
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &user); err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
// GetUserInternal - gets an internal user
|
||||
func GetUserInternal(username string) (models.User, error) {
|
||||
|
||||
|
@ -249,30 +188,6 @@ func GetUserInternal(username string) (models.User, error) {
|
|||
return user, err
|
||||
}
|
||||
|
||||
// GetUsers - gets users
|
||||
func GetUsers() ([]models.ReturnUser, error) {
|
||||
|
||||
var users []models.ReturnUser
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
|
||||
var user models.ReturnUser
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue // get users
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// Get an individual node. Nothin fancy here folks.
|
||||
func getUser(w http.ResponseWriter, r *http.Request) {
|
||||
// set header.
|
||||
|
@ -280,7 +195,7 @@ func getUser(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var params = mux.Vars(r)
|
||||
usernameFetched := params["username"]
|
||||
user, err := GetUser(usernameFetched)
|
||||
user, err := logic.GetUser(usernameFetched)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -295,7 +210,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
|
|||
// set header.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
users, err := GetUsers()
|
||||
users, err := logic.GetUsers()
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -306,42 +221,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(users)
|
||||
}
|
||||
|
||||
// CreateUser - creates a user
|
||||
func CreateUser(user models.User) (models.User, error) {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err == nil {
|
||||
return models.User{}, errors.New("user exists")
|
||||
}
|
||||
err := ValidateUser("create", user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
user.Password = string(hash)
|
||||
|
||||
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
|
||||
|
||||
if tokenString == "" {
|
||||
// returnErrorResponse(w, r, errorResponse)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// connect db
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
func createAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
|
@ -349,7 +228,7 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
// get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&admin)
|
||||
|
||||
admin, err := CreateAdmin(admin)
|
||||
admin, err := logic.CreateAdmin(admin)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
|
@ -359,18 +238,6 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(admin)
|
||||
}
|
||||
|
||||
func CreateAdmin(admin models.User) (models.User, error) {
|
||||
hasadmin, err := HasAdmin()
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if hasadmin {
|
||||
return models.User{}, errors.New("admin user already exists")
|
||||
}
|
||||
admin.IsAdmin = true
|
||||
return CreateUser(admin)
|
||||
}
|
||||
|
||||
func createUser(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
|
@ -378,7 +245,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
|
|||
// get node from body of request
|
||||
_ = json.NewDecoder(r.Body).Decode(&user)
|
||||
|
||||
user, err := CreateUser(user)
|
||||
user, err := logic.CreateUser(user)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
|
@ -388,53 +255,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// UpdateUser - updates a given user
|
||||
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
|
||||
//check if user exists
|
||||
if _, err := GetUser(user.UserName); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
err := ValidateUser("update", userchange)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
queryUser := user.UserName
|
||||
|
||||
if userchange.UserName != "" {
|
||||
user.UserName = userchange.UserName
|
||||
}
|
||||
if len(userchange.Networks) > 0 {
|
||||
user.Networks = userchange.Networks
|
||||
}
|
||||
if userchange.Password != "" {
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
|
||||
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
userchange.Password = string(hash)
|
||||
|
||||
user.Password = userchange.Password
|
||||
}
|
||||
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
functions.PrintUserLog(models.NODE_SERVER_NAME, "updated user "+queryUser, 1)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
var user models.User
|
||||
|
@ -452,8 +273,40 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
err = logic.UpdateUserNetworks(userchange.Networks, userchange.IsAdmin, &user)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
functions.PrintUserLog(username, "status was updated", 1)
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
func updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
var user models.User
|
||||
// start here
|
||||
username := params["username"]
|
||||
user, err := GetUserInternal(username)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if auth.IsOauthUser(&user) == nil {
|
||||
returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user %s", username), "forbidden"))
|
||||
return
|
||||
}
|
||||
var userchange models.User
|
||||
// we decode our body request params
|
||||
err = json.NewDecoder(r.Body).Decode(&userchange)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
userchange.Networks = nil
|
||||
user, err = UpdateUser(userchange, user)
|
||||
user, err = logic.UpdateUser(userchange, user)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -473,6 +326,10 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if auth.IsOauthUser(&user) != nil {
|
||||
returnErrorResponse(w, r, formatError(fmt.Errorf("can not update user info for oauth user"), "forbidden"))
|
||||
return
|
||||
}
|
||||
var userchange models.User
|
||||
// we decode our body request params
|
||||
err = json.NewDecoder(r.Body).Decode(&userchange)
|
||||
|
@ -480,7 +337,7 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
|
|||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
user, err = UpdateUser(userchange, user)
|
||||
user, err = logic.UpdateUser(userchange, user)
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "badrequest"))
|
||||
return
|
||||
|
@ -489,20 +346,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
|
|||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) (bool, error) {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return false, errors.New("user does not exist")
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
// Set header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
@ -511,7 +354,7 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
var params = mux.Vars(r)
|
||||
|
||||
username := params["username"]
|
||||
success, err := DeleteUser(username)
|
||||
success, err := logic.DeleteUser(username)
|
||||
|
||||
if err != nil {
|
||||
returnErrorResponse(w, r, formatError(err, "internal"))
|
||||
|
@ -524,17 +367,3 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
functions.PrintUserLog(username, "was deleted", 1)
|
||||
json.NewEncoder(w).Encode(params["username"] + " deleted.")
|
||||
}
|
||||
|
||||
// ValidateUser - validates a user model
|
||||
func ValidateUser(operation string, user models.User) error {
|
||||
|
||||
v := validator.New()
|
||||
err := v.Struct(user)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
fmt.Println(e)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,52 +4,53 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func deleteAllUsers() {
|
||||
users, _ := GetUsers()
|
||||
users, _ := logic.GetUsers()
|
||||
for _, user := range users {
|
||||
DeleteUser(user.UserName)
|
||||
logic.DeleteUser(user.UserName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAdmin(t *testing.T) {
|
||||
//delete all current users
|
||||
database.InitializeDatabase()
|
||||
users, _ := GetUsers()
|
||||
users, _ := logic.GetUsers()
|
||||
for _, user := range users {
|
||||
success, err := DeleteUser(user.UserName)
|
||||
success, err := logic.DeleteUser(user.UserName)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, success)
|
||||
}
|
||||
t.Run("NoUser", func(t *testing.T) {
|
||||
found, err := HasAdmin()
|
||||
found, err := logic.HasAdmin()
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, found)
|
||||
})
|
||||
t.Run("No admin user", func(t *testing.T) {
|
||||
var user = models.User{"noadmin", "password", nil, false}
|
||||
_, err := CreateUser(user)
|
||||
_, err := logic.CreateUser(user)
|
||||
assert.Nil(t, err)
|
||||
found, err := HasAdmin()
|
||||
found, err := logic.HasAdmin()
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, found)
|
||||
})
|
||||
t.Run("admin user", func(t *testing.T) {
|
||||
var user = models.User{"admin", "password", nil, true}
|
||||
_, err := CreateUser(user)
|
||||
_, err := logic.CreateUser(user)
|
||||
assert.Nil(t, err)
|
||||
found, err := HasAdmin()
|
||||
found, err := logic.HasAdmin()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, found)
|
||||
})
|
||||
t.Run("multiple admins", func(t *testing.T) {
|
||||
var user = models.User{"admin1", "password", nil, true}
|
||||
_, err := CreateUser(user)
|
||||
_, err := logic.CreateUser(user)
|
||||
assert.Nil(t, err)
|
||||
found, err := HasAdmin()
|
||||
found, err := logic.HasAdmin()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, found)
|
||||
})
|
||||
|
@ -60,12 +61,12 @@ func TestCreateUser(t *testing.T) {
|
|||
deleteAllUsers()
|
||||
user := models.User{"admin", "password", nil, true}
|
||||
t.Run("NoUser", func(t *testing.T) {
|
||||
admin, err := CreateUser(user)
|
||||
admin, err := logic.CreateUser(user)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, admin.UserName)
|
||||
})
|
||||
t.Run("UserExists", func(t *testing.T) {
|
||||
_, err := CreateUser(user)
|
||||
_, err := logic.CreateUser(user)
|
||||
assert.NotNil(t, err)
|
||||
assert.EqualError(t, err, "user exists")
|
||||
})
|
||||
|
@ -78,14 +79,14 @@ func TestCreateAdmin(t *testing.T) {
|
|||
t.Run("NoAdmin", func(t *testing.T) {
|
||||
user.UserName = "admin"
|
||||
user.Password = "password"
|
||||
admin, err := CreateAdmin(user)
|
||||
admin, err := logic.CreateAdmin(user)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, admin.UserName)
|
||||
})
|
||||
t.Run("AdminExists", func(t *testing.T) {
|
||||
user.UserName = "admin2"
|
||||
user.Password = "password1"
|
||||
admin, err := CreateAdmin(user)
|
||||
admin, err := logic.CreateAdmin(user)
|
||||
assert.EqualError(t, err, "admin user already exists")
|
||||
assert.Equal(t, admin, models.User{})
|
||||
})
|
||||
|
@ -95,14 +96,14 @@ func TestDeleteUser(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
deleteAllUsers()
|
||||
t.Run("NonExistent User", func(t *testing.T) {
|
||||
deleted, err := DeleteUser("admin")
|
||||
deleted, err := logic.DeleteUser("admin")
|
||||
assert.EqualError(t, err, "user does not exist")
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
t.Run("Existing User", func(t *testing.T) {
|
||||
user := models.User{"admin", "password", nil, true}
|
||||
CreateUser(user)
|
||||
deleted, err := DeleteUser("admin")
|
||||
logic.CreateUser(user)
|
||||
deleted, err := logic.DeleteUser("admin")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, deleted)
|
||||
})
|
||||
|
@ -114,44 +115,44 @@ func TestValidateUser(t *testing.T) {
|
|||
t.Run("Valid Create", func(t *testing.T) {
|
||||
user.UserName = "admin"
|
||||
user.Password = "validpass"
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("Valid Update", func(t *testing.T) {
|
||||
user.UserName = "admin"
|
||||
user.Password = "password"
|
||||
err := ValidateUser("update", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("Invalid UserName", func(t *testing.T) {
|
||||
t.Skip()
|
||||
user.UserName = "*invalid"
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.Error(t, err)
|
||||
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
|
||||
})
|
||||
t.Run("Short UserName", func(t *testing.T) {
|
||||
t.Skip()
|
||||
user.UserName = "1"
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.NotNil(t, err)
|
||||
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
|
||||
})
|
||||
t.Run("Empty UserName", func(t *testing.T) {
|
||||
t.Skip()
|
||||
user.UserName = ""
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.EqualError(t, err, "some string")
|
||||
//assert.Contains(t, err.Error(), "Field validation for 'UserName' failed")
|
||||
})
|
||||
t.Run("EmptyPassword", func(t *testing.T) {
|
||||
user.Password = ""
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'required' tag")
|
||||
})
|
||||
t.Run("ShortPassword", func(t *testing.T) {
|
||||
user.Password = "123"
|
||||
err := ValidateUser("create", user)
|
||||
err := logic.ValidateUser(user)
|
||||
assert.EqualError(t, err, "Key: 'User.Password' Error:Field validation for 'Password' failed on the 'min' tag")
|
||||
})
|
||||
}
|
||||
|
@ -160,14 +161,14 @@ func TestGetUser(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
deleteAllUsers()
|
||||
t.Run("NonExistantUser", func(t *testing.T) {
|
||||
admin, err := GetUser("admin")
|
||||
admin, err := logic.GetUser("admin")
|
||||
assert.EqualError(t, err, "could not find any records")
|
||||
assert.Equal(t, "", admin.UserName)
|
||||
})
|
||||
t.Run("UserExisits", func(t *testing.T) {
|
||||
user := models.User{"admin", "password", nil, true}
|
||||
CreateUser(user)
|
||||
admin, err := GetUser("admin")
|
||||
logic.CreateUser(user)
|
||||
admin, err := logic.GetUser("admin")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, admin.UserName)
|
||||
})
|
||||
|
@ -183,7 +184,7 @@ func TestGetUserInternal(t *testing.T) {
|
|||
})
|
||||
t.Run("UserExisits", func(t *testing.T) {
|
||||
user := models.User{"admin", "password", nil, true}
|
||||
CreateUser(user)
|
||||
logic.CreateUser(user)
|
||||
admin, err := GetUserInternal("admin")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, admin.UserName)
|
||||
|
@ -194,21 +195,21 @@ func TestGetUsers(t *testing.T) {
|
|||
database.InitializeDatabase()
|
||||
deleteAllUsers()
|
||||
t.Run("NonExistantUser", func(t *testing.T) {
|
||||
admin, err := GetUsers()
|
||||
admin, err := logic.GetUsers()
|
||||
assert.EqualError(t, err, "could not find any records")
|
||||
assert.Equal(t, []models.ReturnUser(nil), admin)
|
||||
})
|
||||
t.Run("UserExisits", func(t *testing.T) {
|
||||
user := models.User{"admin", "password", nil, true}
|
||||
CreateUser(user)
|
||||
admins, err := GetUsers()
|
||||
logic.CreateUser(user)
|
||||
admins, err := logic.GetUsers()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.UserName, admins[0].UserName)
|
||||
})
|
||||
t.Run("MulipleUsers", func(t *testing.T) {
|
||||
user := models.User{"user", "password", nil, true}
|
||||
CreateUser(user)
|
||||
admins, err := GetUsers()
|
||||
logic.CreateUser(user)
|
||||
admins, err := logic.GetUsers()
|
||||
assert.Nil(t, err)
|
||||
for _, u := range admins {
|
||||
if u.UserName == "admin" {
|
||||
|
@ -227,14 +228,14 @@ func TestUpdateUser(t *testing.T) {
|
|||
user := models.User{"admin", "password", nil, true}
|
||||
newuser := models.User{"hello", "world", []string{"wirecat, netmaker"}, true}
|
||||
t.Run("NonExistantUser", func(t *testing.T) {
|
||||
admin, err := UpdateUser(newuser, user)
|
||||
admin, err := logic.UpdateUser(newuser, user)
|
||||
assert.EqualError(t, err, "could not find any records")
|
||||
assert.Equal(t, "", admin.UserName)
|
||||
})
|
||||
|
||||
t.Run("UserExists", func(t *testing.T) {
|
||||
CreateUser(user)
|
||||
admin, err := UpdateUser(newuser, user)
|
||||
logic.CreateUser(user)
|
||||
admin, err := logic.UpdateUser(newuser, user)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, newuser.UserName, admin.UserName)
|
||||
})
|
||||
|
@ -271,43 +272,43 @@ func TestVerifyAuthRequest(t *testing.T) {
|
|||
t.Run("EmptyUserName", func(t *testing.T) {
|
||||
authRequest.UserName = ""
|
||||
authRequest.Password = "Password"
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.Equal(t, "", jwt)
|
||||
assert.EqualError(t, err, "username can't be empty")
|
||||
})
|
||||
t.Run("EmptyPassword", func(t *testing.T) {
|
||||
authRequest.UserName = "admin"
|
||||
authRequest.Password = ""
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.Equal(t, "", jwt)
|
||||
assert.EqualError(t, err, "password can't be empty")
|
||||
})
|
||||
t.Run("NonExistantUser", func(t *testing.T) {
|
||||
authRequest.UserName = "admin"
|
||||
authRequest.Password = "password"
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.Equal(t, "", jwt)
|
||||
assert.EqualError(t, err, "incorrect credentials")
|
||||
})
|
||||
t.Run("Non-Admin", func(t *testing.T) {
|
||||
user := models.User{"nonadmin", "somepass", nil, false}
|
||||
CreateUser(user)
|
||||
logic.CreateUser(user)
|
||||
authRequest := models.UserAuthParams{"nonadmin", "somepass"}
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.NotNil(t, jwt)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
t.Run("WrongPassword", func(t *testing.T) {
|
||||
user := models.User{"admin", "password", nil, false}
|
||||
CreateUser(user)
|
||||
logic.CreateUser(user)
|
||||
authRequest := models.UserAuthParams{"admin", "badpass"}
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.Equal(t, "", jwt)
|
||||
assert.EqualError(t, err, "incorrect credentials")
|
||||
})
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
authRequest := models.UserAuthParams{"admin", "password"}
|
||||
jwt, err := VerifyAuthRequest(authRequest)
|
||||
jwt, err := logic.VerifyAuthRequest(authRequest)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, jwt)
|
||||
})
|
||||
|
|
|
@ -39,6 +39,9 @@ const SERVERCONF_TABLE_NAME = "serverconf"
|
|||
// DATABASE_FILENAME - database file name
|
||||
const DATABASE_FILENAME = "netmaker.db"
|
||||
|
||||
// GENERATED_TABLE_NAME - stores server generated k/v
|
||||
const GENERATED_TABLE_NAME = "generated"
|
||||
|
||||
// == ERROR CONSTS ==
|
||||
|
||||
// NO_RECORD - no singular result found
|
||||
|
@ -87,11 +90,11 @@ func getCurrentDB() map[string]interface{} {
|
|||
}
|
||||
|
||||
func InitializeDatabase() error {
|
||||
log.Println("connecting to", servercfg.GetDB())
|
||||
log.Println("[netmaker] connecting to", servercfg.GetDB())
|
||||
tperiod := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil {
|
||||
log.Println("unable to connect to db, retrying . . .")
|
||||
log.Println("[netmaker] unable to connect to db, retrying . . .")
|
||||
if time.Now().After(tperiod) {
|
||||
return err
|
||||
}
|
||||
|
@ -114,6 +117,7 @@ func createTables() {
|
|||
createTable(INT_CLIENTS_TABLE_NAME)
|
||||
createTable(PEERS_TABLE_NAME)
|
||||
createTable(SERVERCONF_TABLE_NAME)
|
||||
createTable(GENERATED_TABLE_NAME)
|
||||
}
|
||||
|
||||
func createTable(tableName string) error {
|
||||
|
|
1
go.mod
1
go.mod
|
@ -17,6 +17,7 @@ require (
|
|||
github.com/urfave/cli/v2 v2.3.0
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
|
||||
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect
|
||||
golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e // indirect
|
||||
golang.org/x/text v0.3.7-0.20210524175448-3115f89c4b99 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 // indirect
|
||||
|
|
268
logic/auth.go
Normal file
268
logic/auth.go
Normal file
|
@ -0,0 +1,268 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HasAdmin - checks if server has an admin
|
||||
func HasAdmin() (bool, error) {
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return false, nil
|
||||
} else {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
for _, value := range collection { // filter for isadmin true
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if user.IsAdmin {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
// GetUser - gets a user
|
||||
func GetUser(username string) (models.ReturnUser, error) {
|
||||
|
||||
var user models.ReturnUser
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &user); err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
// GetUsers - gets users
|
||||
func GetUsers() ([]models.ReturnUser, error) {
|
||||
|
||||
var users []models.ReturnUser
|
||||
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
|
||||
var user models.ReturnUser
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue // get users
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// CreateUser - creates a user
|
||||
func CreateUser(user models.User) (models.User, error) {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err == nil {
|
||||
return models.User{}, errors.New("user exists")
|
||||
}
|
||||
var err = ValidateUser(user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
user.Password = string(hash)
|
||||
|
||||
tokenString, _ := functions.CreateUserJWT(user.UserName, user.Networks, user.IsAdmin)
|
||||
|
||||
if tokenString == "" {
|
||||
// returnErrorResponse(w, r, errorResponse)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// connect db
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
// CreateAdmin - creates an admin user
|
||||
func CreateAdmin(admin models.User) (models.User, error) {
|
||||
hasadmin, err := HasAdmin()
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if hasadmin {
|
||||
return models.User{}, errors.New("admin user already exists")
|
||||
}
|
||||
admin.IsAdmin = true
|
||||
return CreateUser(admin)
|
||||
}
|
||||
|
||||
// VerifyAuthRequest - verifies an auth request
|
||||
func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
|
||||
var result models.User
|
||||
if authRequest.UserName == "" {
|
||||
return "", errors.New("username can't be empty")
|
||||
} else if authRequest.Password == "" {
|
||||
return "", errors.New("password can't be empty")
|
||||
}
|
||||
//Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
|
||||
if err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &result); err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
|
||||
// compare password from request to stored password in database
|
||||
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
|
||||
// TODO: Consider a way of hashing the password client side before sending, or using certificates
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
|
||||
//Create a new JWT for the node
|
||||
tokenString, _ := functions.CreateUserJWT(authRequest.UserName, result.Networks, result.IsAdmin)
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// UpdateUserNetworks - updates the networks of a given user
|
||||
func UpdateUserNetworks(newNetworks []string, isadmin bool, currentUser *models.User) error {
|
||||
// check if user exists
|
||||
if returnedUser, err := GetUser(currentUser.UserName); err != nil {
|
||||
return err
|
||||
} else if returnedUser.IsAdmin {
|
||||
return fmt.Errorf("can not make changes to an admin user, attempted to change %s", returnedUser.UserName)
|
||||
}
|
||||
if isadmin {
|
||||
currentUser.IsAdmin = true
|
||||
currentUser.Networks = nil
|
||||
} else {
|
||||
currentUser.Networks = newNetworks
|
||||
}
|
||||
|
||||
data, err := json.Marshal(currentUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = database.Insert(currentUser.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUser - updates a given user
|
||||
func UpdateUser(userchange models.User, user models.User) (models.User, error) {
|
||||
//check if user exists
|
||||
if _, err := GetUser(user.UserName); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
err := ValidateUser(userchange)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
queryUser := user.UserName
|
||||
|
||||
if userchange.UserName != "" {
|
||||
user.UserName = userchange.UserName
|
||||
}
|
||||
if len(userchange.Networks) > 0 {
|
||||
user.Networks = userchange.Networks
|
||||
}
|
||||
if userchange.Password != "" {
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
|
||||
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
// set password to encrypted password
|
||||
userchange.Password = string(hash)
|
||||
|
||||
user.Password = userchange.Password
|
||||
}
|
||||
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return models.User{}, err
|
||||
}
|
||||
Log("updated user "+queryUser, 1)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ValidateUser - validates a user model
|
||||
func ValidateUser(user models.User) error {
|
||||
|
||||
v := validator.New()
|
||||
err := v.Struct(user)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
Log(e.Error(), 2)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) (bool, error) {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return false, errors.New("user does not exist")
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FetchAuthSecret - manages secrets for oauth
|
||||
func FetchAuthSecret(key string, secret string) (string, error) {
|
||||
var record, err = database.FetchRecord(database.GENERATED_TABLE_NAME, key)
|
||||
if err != nil {
|
||||
if err = database.Insert(key, secret, database.GENERATED_TABLE_NAME); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return secret, nil
|
||||
}
|
||||
}
|
||||
return record, nil
|
||||
}
|
|
@ -5,10 +5,17 @@ import (
|
|||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
)
|
||||
|
||||
// CheckNetworkExists - checks i a network exists for this netmaker instance
|
||||
func CheckNetworkExists(network string) bool {
|
||||
var _, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, network)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GetLocalIP - gets the local ip
|
||||
func GetLocalIP(node models.Node) string {
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -278,6 +279,19 @@ func GetPeersList(networkName string, excludeRelayed bool, relayedNodeAddr strin
|
|||
return peers, err
|
||||
}
|
||||
|
||||
// RandomString - returns a random string in a charset
|
||||
func RandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func setPeerInfo(node models.Node) models.Node {
|
||||
var peer models.Node
|
||||
peer.RelayAddrs = node.RelayAddrs
|
||||
|
@ -303,7 +317,7 @@ func setPeerInfo(node models.Node) models.Node {
|
|||
|
||||
func Log(message string, loglevel int) {
|
||||
log.SetFlags(log.Flags() &^ (log.Llongfile | log.Lshortfile))
|
||||
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() != 0 {
|
||||
if int32(loglevel) <= servercfg.GetVerbose() && servercfg.GetVerbose() >= 0 {
|
||||
log.Println("[netmaker] " + message)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,10 +62,11 @@ func setWGConfig(node models.Node, network string, peerupdate bool) error {
|
|||
var iface string
|
||||
iface = node.Interface
|
||||
err = setServerPeers(iface, node.PersistentKeepalive, peers)
|
||||
Log("updated peers on server "+node.Name, 2)
|
||||
} else {
|
||||
err = initWireguard(&node, privkey, peers, hasGateway, gateways)
|
||||
Log("finished setting wg config on server "+node.Name, 3)
|
||||
}
|
||||
Log("finished setting wg config on server "+node.Name, 1)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
39
main.go
39
main.go
|
@ -13,11 +13,13 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/dnslogic"
|
||||
"github.com/gravitl/netmaker/functions"
|
||||
nodepb "github.com/gravitl/netmaker/grpc"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
|
@ -35,20 +37,29 @@ func main() {
|
|||
|
||||
func initialize() { // Client Mode Prereq Check
|
||||
var err error
|
||||
|
||||
if err = database.InitializeDatabase(); err != nil {
|
||||
log.Println("Error connecting to database.")
|
||||
logic.Log("Error connecting to database", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Println("database successfully connected.")
|
||||
logic.Log("database successfully connected", 0)
|
||||
|
||||
var authProvider = auth.InitializeAuthProvider()
|
||||
if authProvider != "" {
|
||||
logic.Log("OAuth provider, "+authProvider+", initialized", 0)
|
||||
} else {
|
||||
logic.Log("no OAuth provider found or not configured, continuing without OAuth", 0)
|
||||
}
|
||||
|
||||
if servercfg.IsClientMode() != "off" {
|
||||
output, err := ncutils.RunCmd("id -u", true)
|
||||
if err != nil {
|
||||
log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.")
|
||||
logic.Log("Error running 'id -u' for prereq check. Please investigate or disable client mode.", 0)
|
||||
log.Fatal(output, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(string(output[:len(output)-1]))
|
||||
if err != nil {
|
||||
log.Println("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.")
|
||||
logic.Log("Error retrieving uid from 'id -u' for prereq check. Please investigate or disable client mode.", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
if uid != 0 {
|
||||
|
@ -74,7 +85,7 @@ func startControllers() {
|
|||
if !(servercfg.DisableRemoteIPCheck()) && servercfg.GetGRPCHost() == "127.0.0.1" {
|
||||
err := servercfg.SetHost()
|
||||
if err != nil {
|
||||
log.Println("Unable to Set host. Exiting...")
|
||||
logic.Log("Unable to Set host. Exiting...", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +101,7 @@ func startControllers() {
|
|||
if servercfg.IsDNSMode() {
|
||||
err := dnslogic.SetDNS()
|
||||
if err != nil {
|
||||
log.Println("error occurred initializing DNS:", err)
|
||||
logic.Log("error occurred initializing DNS: "+err.Error(), 0)
|
||||
}
|
||||
}
|
||||
//Run Rest Server
|
||||
|
@ -98,7 +109,7 @@ func startControllers() {
|
|||
if !servercfg.DisableRemoteIPCheck() && servercfg.GetAPIHost() == "127.0.0.1" {
|
||||
err := servercfg.SetHost()
|
||||
if err != nil {
|
||||
log.Println("Unable to Set host. Exiting...")
|
||||
logic.Log("Unable to Set host. Exiting...", 0)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -106,11 +117,11 @@ func startControllers() {
|
|||
controller.HandleRESTRequests(&waitnetwork)
|
||||
}
|
||||
if !servercfg.IsAgentBackend() && !servercfg.IsRestBackend() {
|
||||
log.Println("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.")
|
||||
logic.Log("No Server Mode selected, so nothing is being served! Set either Agent mode (AGENT_BACKEND) or Rest mode (REST_BACKEND) to 'true'.", 0)
|
||||
}
|
||||
|
||||
waitnetwork.Wait()
|
||||
log.Println("[netmaker] exiting")
|
||||
logic.Log("exiting", 0)
|
||||
}
|
||||
|
||||
func runClient(wg *sync.WaitGroup) {
|
||||
|
@ -139,7 +150,7 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
listener, err := net.Listen("tcp", ":"+grpcport)
|
||||
// Handle errors if any
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to listen on port "+grpcport+", error: %v", err)
|
||||
log.Fatalf("[netmaker] Unable to listen on port "+grpcport+", error: %v", err)
|
||||
}
|
||||
|
||||
s := grpc.NewServer(
|
||||
|
@ -157,7 +168,7 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
log.Fatalf("Failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
log.Println("Agent Server successfully started on port " + grpcport + " (gRPC)")
|
||||
logic.Log("Agent Server successfully started on port "+grpcport+" (gRPC)", 0)
|
||||
|
||||
// Right way to stop the server using a SHUTDOWN HOOK
|
||||
// Create a channel to receive OS signals
|
||||
|
@ -172,11 +183,11 @@ func runGRPC(wg *sync.WaitGroup) {
|
|||
<-c
|
||||
|
||||
// After receiving CTRL+C Properly stop the server
|
||||
log.Println("Stopping the Agent server...")
|
||||
logic.Log("Stopping the Agent server...", 0)
|
||||
s.Stop()
|
||||
listener.Close()
|
||||
log.Println("Agent server closed..")
|
||||
log.Println("Closed DB connection.")
|
||||
logic.Log("Agent server closed..", 0)
|
||||
logic.Log("Closed DB connection.", 0)
|
||||
}
|
||||
|
||||
func authServerUnaryInterceptor() grpc.ServerOption {
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
nodepb "github.com/gravitl/netmaker/grpc"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/auth"
|
||||
|
@ -18,6 +16,8 @@ import (
|
|||
"github.com/gravitl/netmaker/netclient/wireguard"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"log"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// JoinNetwork - helps a client join a network
|
||||
|
@ -84,8 +84,8 @@ func JoinNetwork(cfg config.ClientConfig, privateKey string) error {
|
|||
if ncutils.IsLinux() {
|
||||
_, err := exec.LookPath("resolvectl")
|
||||
if err != nil {
|
||||
ncutils.PrintLog("resolvectl not present",2)
|
||||
ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management",2)
|
||||
ncutils.PrintLog("resolvectl not present", 2)
|
||||
ncutils.PrintLog("unable to configure DNS automatically, disabling automated DNS management", 2)
|
||||
cfg.Node.DNSOn = "no"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -155,9 +155,9 @@ func parsePeers(keepalive int32, peers []wgtypes.PeerConfig) (string, error) {
|
|||
if keepalive <= 0 {
|
||||
keepalive = 20
|
||||
}
|
||||
|
||||
|
||||
for _, peer := range peers {
|
||||
endpointString := ""
|
||||
endpointString := ""
|
||||
if peer.Endpoint != nil && peer.Endpoint.String() != "" {
|
||||
endpointString += "Endpoint = " + peer.Endpoint.String()
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/config"
|
||||
)
|
||||
|
@ -65,8 +66,26 @@ func GetServerConfig() config.ServerConfig {
|
|||
cfg.Database = GetDB()
|
||||
cfg.Platform = GetPlatform()
|
||||
cfg.Version = GetVersion()
|
||||
|
||||
// == auth config ==
|
||||
var authInfo = GetAuthProviderInfo()
|
||||
cfg.AuthProvider = authInfo[0]
|
||||
cfg.ClientID = authInfo[1]
|
||||
cfg.ClientSecret = authInfo[2]
|
||||
cfg.FrontendURL = GetFrontendURL()
|
||||
|
||||
return cfg
|
||||
}
|
||||
func GetFrontendURL() string {
|
||||
var frontend = ""
|
||||
if os.Getenv("FRONTEND_URL") != "" {
|
||||
frontend = os.Getenv("FRONTEND_URL")
|
||||
} else if config.Config.Server.FrontendURL != "" {
|
||||
frontend = config.Config.Server.FrontendURL
|
||||
}
|
||||
return frontend
|
||||
}
|
||||
|
||||
func GetAPIConnString() string {
|
||||
conn := ""
|
||||
if os.Getenv("SERVER_API_CONN_STRING") != "" {
|
||||
|
@ -77,7 +96,7 @@ func GetAPIConnString() string {
|
|||
return conn
|
||||
}
|
||||
func GetVersion() string {
|
||||
version := "0.8.4"
|
||||
version := "0.8.5"
|
||||
if config.Config.Server.Version != "" {
|
||||
version = config.Config.Server.Version
|
||||
}
|
||||
|
@ -398,6 +417,25 @@ func GetServerCheckinInterval() int64 {
|
|||
return t
|
||||
}
|
||||
|
||||
// GetAuthProviderInfo = gets the oauth provider info
|
||||
func GetAuthProviderInfo() []string {
|
||||
var authProvider = ""
|
||||
if os.Getenv("AUTH_PROVIDER") != "" && os.Getenv("CLIENT_ID") != "" && os.Getenv("CLIENT_SECRET") != "" {
|
||||
authProvider = strings.ToLower(os.Getenv("AUTH_PROVIDER"))
|
||||
if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
|
||||
return []string{authProvider, os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET")}
|
||||
} else {
|
||||
authProvider = ""
|
||||
}
|
||||
} else if config.Config.Server.AuthProvider != "" && config.Config.Server.ClientID != "" && config.Config.Server.ClientSecret != "" {
|
||||
authProvider = strings.ToLower(config.Config.Server.AuthProvider)
|
||||
if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
|
||||
return []string{authProvider, config.Config.Server.ClientID, config.Config.Server.ClientSecret}
|
||||
}
|
||||
}
|
||||
return []string{"", "", ""}
|
||||
}
|
||||
|
||||
// GetMacAddr - get's mac address
|
||||
func getMacAddr() string {
|
||||
ifas, err := net.Interfaces()
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package servercfg
|
||||
|
||||
import (
|
||||
"os"
|
||||
"github.com/gravitl/netmaker/config"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue