mirror of
https://github.com/gravitl/netmaker.git
synced 2024-09-20 15:26:04 +08:00
NET-1134:move oauth from CE build block to pro (#2919)
* move oauth from CE build block to pro * move oauth code and api handler under pro * move common func back to auth from pro/auth * change log level to Info for information logs * fix import issue
This commit is contained in:
parent
7eb1cf49e0
commit
da11dc8a87
341
auth/auth.go
341
auth/auth.go
|
@ -3,156 +3,25 @@ package auth
|
|||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
// == consts ==
|
||||
const (
|
||||
init_provider = "initprovider"
|
||||
get_user_info = "getuserinfo"
|
||||
handle_callback = "handlecallback"
|
||||
handle_login = "handlelogin"
|
||||
google_provider_name = "google"
|
||||
azure_ad_provider_name = "azure-ad"
|
||||
github_provider_name = "github"
|
||||
oidc_provider_name = "oidc"
|
||||
verify_user = "verifyuser"
|
||||
user_signin_length = 16
|
||||
node_signin_length = 64
|
||||
headless_signin_length = 32
|
||||
node_signin_length = 64
|
||||
)
|
||||
|
||||
// OAuthUser - generic OAuth strategy user
|
||||
type OAuthUser struct {
|
||||
Name string `json:"name" bson:"name"`
|
||||
Email string `json:"email" bson:"email"`
|
||||
Login string `json:"login" bson:"login"`
|
||||
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
var (
|
||||
auth_provider *oauth2.Config
|
||||
upgrader = websocket.Upgrader{}
|
||||
)
|
||||
|
||||
func getCurrentAuthFunctions() map[string]interface{} {
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var authProvider = authInfo[0]
|
||||
switch authProvider {
|
||||
case google_provider_name:
|
||||
return google_functions
|
||||
case azure_ad_provider_name:
|
||||
return azure_ad_functions
|
||||
case github_provider_name:
|
||||
return github_functions
|
||||
case oidc_provider_name:
|
||||
return oidc_functions
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeAuthProvider - initializes the auth provider if any is present
|
||||
func InitializeAuthProvider() string {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return ""
|
||||
}
|
||||
logger.Log(0, "setting oauth secret")
|
||||
var err = logic.SetAuthSecret(logic.RandomString(64))
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to set auth_secret", err.Error())
|
||||
}
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var serverConn = servercfg.GetAPIHost()
|
||||
if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") {
|
||||
serverConn = "http://" + serverConn
|
||||
logger.Log(1, "localhost OAuth detected, proceeding with insecure http redirect: (", serverConn, ")")
|
||||
} else {
|
||||
serverConn = "https://" + serverConn
|
||||
logger.Log(1, "external OAuth detected, proceeding with https redirect: ("+serverConn+")")
|
||||
}
|
||||
|
||||
if authInfo[0] == "oidc" {
|
||||
functions[init_provider].(func(string, string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2], authInfo[3])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
functions[init_provider].(func(string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
// HandleAuthCallback - handles oauth callback
|
||||
// Note: not included in API reference as part of the OAuth process itself.
|
||||
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if auth_provider == nil {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
state, _ := getStateAndCode(r)
|
||||
_, err := netcache.Get(state) // if in netcache proceeed with node registration login
|
||||
if err == nil || errors.Is(err, netcache.ErrExpired) {
|
||||
switch len(state) {
|
||||
case node_signin_length:
|
||||
logger.Log(1, "proceeding with host SSO callback")
|
||||
HandleHostSSOCallback(w, r)
|
||||
case headless_signin_length:
|
||||
logger.Log(1, "proceeding with headless SSO callback")
|
||||
HandleHeadlessSSOCallback(w, r)
|
||||
default:
|
||||
logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
|
||||
}
|
||||
} else { // handle normal login
|
||||
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// swagger:route GET /api/oauth/login nodes HandleAuthLogin
|
||||
//
|
||||
// Handles OAuth login.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
// Responses:
|
||||
// 200: okResponse
|
||||
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if auth_provider == nil {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
if servercfg.GetFrontendURL() == "" {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
|
||||
// IsOauthUser - returns
|
||||
func IsOauthUser(user *models.User) error {
|
||||
var currentValue, err = FetchPassValue("")
|
||||
|
@ -163,116 +32,6 @@ func IsOauthUser(user *models.User) error {
|
|||
return bCryptErr
|
||||
}
|
||||
|
||||
// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
|
||||
func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
logger.Log(0, "failed to establish web-socket connection during headless sign-in")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := &netcache.CValue{User: "", Pass: ""}
|
||||
stateStr := logic.RandomString(headless_signin_length)
|
||||
if err = netcache.Set(stateStr, req); err != nil {
|
||||
logger.Log(0, "Failed to process sso request -", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
timeout := make(chan bool, 1)
|
||||
answer := make(chan string, 1)
|
||||
defer close(answer)
|
||||
defer close(timeout)
|
||||
|
||||
if auth_provider == nil {
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
|
||||
if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
cachedReq, err := netcache.Get(stateStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "expired") {
|
||||
logger.Log(0, "timeout occurred while waiting for SSO")
|
||||
timeout <- true
|
||||
break
|
||||
}
|
||||
continue
|
||||
} else if cachedReq.Pass != "" {
|
||||
logger.Log(0, "SSO process completed for user ", cachedReq.User)
|
||||
answer <- cachedReq.Pass
|
||||
break
|
||||
}
|
||||
time.Sleep(500) // try it 2 times per second to see if auth is completed
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-answer:
|
||||
if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {
|
||||
logger.Log(0, "Error during message writing:", err.Error())
|
||||
}
|
||||
case <-timeout:
|
||||
logger.Log(0, "Authentication server time out for headless SSO login")
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "Error during message writing:", err.Error())
|
||||
}
|
||||
}
|
||||
if err = netcache.Del(stateStr); err != nil {
|
||||
logger.Log(0, "failed to remove SSO cache entry", err.Error())
|
||||
}
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "write close:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// == private methods ==
|
||||
|
||||
func addUser(email string) error {
|
||||
var hasSuperAdmin, err = logic.HasSuperAdmin()
|
||||
if err != nil {
|
||||
slog.Error("error checking for existence of admin user during OAuth login for", "email", email, "error", err)
|
||||
return err
|
||||
} // generate random password to adapt to current model
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
slog.Error("failed to get password", "error", err.Error())
|
||||
return fetchErr
|
||||
}
|
||||
var newUser = models.User{
|
||||
UserName: email,
|
||||
Password: newPass,
|
||||
}
|
||||
if !hasSuperAdmin { // must be first attempt, create a superadmin
|
||||
logger.Log(0, "creating superadmin")
|
||||
if err = logic.CreateSuperAdmin(&newUser); err != nil {
|
||||
slog.Error("error creating super admin from user", "email", email, "error", err)
|
||||
} else {
|
||||
slog.Info("superadmin created from user", "email", email)
|
||||
}
|
||||
} else { // otherwise add to db as admin..?
|
||||
// TODO: add ability to add users with preemptive permissions
|
||||
newUser.IsAdmin = false
|
||||
if err = logic.CreateUser(&newUser); err != nil {
|
||||
logger.Log(0, "error creating user,", email, "; user not added", "error", err.Error())
|
||||
} else {
|
||||
logger.Log(0, "user created from ", email)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchPassValue(newValue string) (string, error) {
|
||||
|
||||
type valueHolder struct {
|
||||
|
@ -296,54 +55,56 @@ func FetchPassValue(newValue string) (string, error) {
|
|||
return string(b64CurrentValue), nil
|
||||
}
|
||||
|
||||
func getStateAndCode(r *http.Request) (string, string) {
|
||||
var state, code string
|
||||
if r.FormValue("state") != "" && r.FormValue("code") != "" {
|
||||
state = r.FormValue("state")
|
||||
code = r.FormValue("code")
|
||||
} else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" {
|
||||
state = r.URL.Query().Get("state")
|
||||
code = r.URL.Query().Get("code")
|
||||
}
|
||||
// == private ==
|
||||
|
||||
return state, code
|
||||
}
|
||||
|
||||
func (user *OAuthUser) getUserName() string {
|
||||
var userName string
|
||||
if user.Email != "" {
|
||||
userName = user.Email
|
||||
} else if user.Login != "" {
|
||||
userName = user.Login
|
||||
} else if user.UserPrincipalName != "" {
|
||||
userName = user.UserPrincipalName
|
||||
} else if user.Name != "" {
|
||||
userName = user.Name
|
||||
func addUser(email string) error {
|
||||
var hasSuperAdmin, err = logic.HasSuperAdmin()
|
||||
if err != nil {
|
||||
slog.Error("error checking for existence of admin user during OAuth login for", "email", email, "error", err)
|
||||
return err
|
||||
} // generate random password to adapt to current model
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
slog.Error("failed to get password", "error", fetchErr.Error())
|
||||
return fetchErr
|
||||
}
|
||||
return userName
|
||||
}
|
||||
|
||||
func isStateCached(state string) bool {
|
||||
_, err := netcache.Get(state)
|
||||
return err == nil || strings.Contains(err.Error(), "expired")
|
||||
}
|
||||
|
||||
// isEmailAllowed - checks if email is allowed to signup
|
||||
func isEmailAllowed(email string) bool {
|
||||
allowedDomains := servercfg.GetAllowedEmailDomains()
|
||||
domains := strings.Split(allowedDomains, ",")
|
||||
if len(domains) == 1 && domains[0] == "*" {
|
||||
return true
|
||||
var newUser = models.User{
|
||||
UserName: email,
|
||||
Password: newPass,
|
||||
}
|
||||
emailParts := strings.Split(email, "@")
|
||||
if len(emailParts) < 2 {
|
||||
return false
|
||||
}
|
||||
baseDomainOfEmail := emailParts[1]
|
||||
for _, domain := range domains {
|
||||
if domain == baseDomainOfEmail {
|
||||
return true
|
||||
if !hasSuperAdmin { // must be first attempt, create a superadmin
|
||||
logger.Log(0, "creating superadmin")
|
||||
if err = logic.CreateSuperAdmin(&newUser); err != nil {
|
||||
slog.Error("error creating super admin from user", "email", email, "error", err)
|
||||
} else {
|
||||
slog.Info("superadmin created from user", "email", email)
|
||||
}
|
||||
} else { // otherwise add to db as admin..?
|
||||
// TODO: add ability to add users with preemptive permissions
|
||||
newUser.IsAdmin = false
|
||||
if err = logic.CreateUser(&newUser); err != nil {
|
||||
logger.Log(0, "error creating user,", email, "; user not added", "error", err.Error())
|
||||
} else {
|
||||
logger.Log(0, "user created from ", email)
|
||||
}
|
||||
}
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
func isUserIsAllowed(username, network string, shouldAddUser bool) (*models.User, error) {
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
if err != nil && shouldAddUser { // user must not exist, so try to make one
|
||||
if err = addUser(username); err != nil {
|
||||
logger.Log(0, "failed to add user", username, "during a node SSO network join on network", network)
|
||||
// response := returnErrTemplate(user.UserName, "failed to add user", state, reqKeyIf)
|
||||
// w.WriteHeader(http.StatusInternalServerError)
|
||||
// w.Write(response)
|
||||
return nil, fmt.Errorf("failed to add user to system")
|
||||
}
|
||||
logger.Log(0, "user", username, "was added during a node SSO network join on network", network)
|
||||
user, _ = logic.GetUser(username)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
|
|
@ -121,7 +121,7 @@ func SessionHandler(conn *websocket.Conn) {
|
|||
return
|
||||
}
|
||||
logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "via SSO")
|
||||
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
|
||||
redirectUrl := fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
|
||||
err = conn.WriteMessage(messageType, []byte(redirectUrl))
|
||||
if err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
|
|
|
@ -202,7 +202,7 @@ func Authorize(hostAllowed, networkCheck bool, authNetwork string, next http.Han
|
|||
}
|
||||
|
||||
isnetadmin := issuperadmin || isadmin
|
||||
if errN == nil && (issuperadmin || isadmin) {
|
||||
if issuperadmin || isadmin {
|
||||
nodeID = "mastermac"
|
||||
isAuthorized = true
|
||||
r.Header.Set("ismasterkey", "yes")
|
||||
|
|
|
@ -32,10 +32,6 @@ func userHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/headless", auth.HandleHeadlessSSO)
|
||||
r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterHostSSO).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/users_pending/user/{username}", logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser))).Methods(http.MethodDelete)
|
||||
|
@ -119,7 +115,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
|
|||
successJSONResponse, jsonError := json.Marshal(successResponse)
|
||||
if jsonError != nil {
|
||||
logger.Log(0, username,
|
||||
"error marshalling resp: ", err.Error())
|
||||
"error marshalling resp: ", jsonError.Error())
|
||||
logic.ReturnErrorResponse(response, request, errorResponse)
|
||||
return
|
||||
}
|
||||
|
|
8
main.go
8
main.go
|
@ -12,7 +12,6 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/config"
|
||||
controller "github.com/gravitl/netmaker/controllers"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
|
@ -91,13 +90,6 @@ func initialize() { // Client Mode Prereq Check
|
|||
|
||||
logic.SetJWTSecret()
|
||||
|
||||
var authProvider = auth.InitializeAuthProvider()
|
||||
if authProvider != "" {
|
||||
logger.Log(0, "OAuth provider,", authProvider+",", "initialized")
|
||||
} else {
|
||||
logger.Log(0, "no OAuth provider found or not configured, continuing without OAuth")
|
||||
}
|
||||
|
||||
err = serverctl.SetDefaults()
|
||||
if err != nil {
|
||||
logger.FatalLog("error setting defaults: ", err.Error())
|
||||
|
|
276
pro/auth/auth.go
Normal file
276
pro/auth/auth.go
Normal file
|
@ -0,0 +1,276 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// == consts ==
|
||||
const (
|
||||
init_provider = "initprovider"
|
||||
get_user_info = "getuserinfo"
|
||||
handle_callback = "handlecallback"
|
||||
handle_login = "handlelogin"
|
||||
google_provider_name = "google"
|
||||
azure_ad_provider_name = "azure-ad"
|
||||
github_provider_name = "github"
|
||||
oidc_provider_name = "oidc"
|
||||
verify_user = "verifyuser"
|
||||
user_signin_length = 16
|
||||
node_signin_length = 64
|
||||
headless_signin_length = 32
|
||||
)
|
||||
|
||||
// OAuthUser - generic OAuth strategy user
|
||||
type OAuthUser struct {
|
||||
Name string `json:"name" bson:"name"`
|
||||
Email string `json:"email" bson:"email"`
|
||||
Login string `json:"login" bson:"login"`
|
||||
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
|
||||
AccessToken string `json:"accesstoken" bson:"accesstoken"`
|
||||
}
|
||||
|
||||
var (
|
||||
auth_provider *oauth2.Config
|
||||
upgrader = websocket.Upgrader{}
|
||||
)
|
||||
|
||||
func getCurrentAuthFunctions() map[string]interface{} {
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var authProvider = authInfo[0]
|
||||
switch authProvider {
|
||||
case google_provider_name:
|
||||
return google_functions
|
||||
case azure_ad_provider_name:
|
||||
return azure_ad_functions
|
||||
case github_provider_name:
|
||||
return github_functions
|
||||
case oidc_provider_name:
|
||||
return oidc_functions
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeAuthProvider - initializes the auth provider if any is present
|
||||
func InitializeAuthProvider() string {
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return ""
|
||||
}
|
||||
logger.Log(0, "setting oauth secret")
|
||||
var err = logic.SetAuthSecret(logic.RandomString(64))
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to set auth_secret", err.Error())
|
||||
}
|
||||
var authInfo = servercfg.GetAuthProviderInfo()
|
||||
var serverConn = servercfg.GetAPIHost()
|
||||
if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") {
|
||||
serverConn = "http://" + serverConn
|
||||
logger.Log(1, "localhost OAuth detected, proceeding with insecure http redirect: (", serverConn, ")")
|
||||
} else {
|
||||
serverConn = "https://" + serverConn
|
||||
logger.Log(1, "external OAuth detected, proceeding with https redirect: ("+serverConn+")")
|
||||
}
|
||||
|
||||
if authInfo[0] == "oidc" {
|
||||
functions[init_provider].(func(string, string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2], authInfo[3])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
functions[init_provider].(func(string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2])
|
||||
return authInfo[0]
|
||||
}
|
||||
|
||||
// HandleAuthCallback - handles oauth callback
|
||||
// Note: not included in API reference as part of the OAuth process itself.
|
||||
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if auth_provider == nil {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
state, _ := getStateAndCode(r)
|
||||
_, err := netcache.Get(state) // if in netcache proceeed with node registration login
|
||||
if err == nil || errors.Is(err, netcache.ErrExpired) {
|
||||
switch len(state) {
|
||||
case node_signin_length:
|
||||
logger.Log(1, "proceeding with host SSO callback")
|
||||
HandleHostSSOCallback(w, r)
|
||||
case headless_signin_length:
|
||||
logger.Log(1, "proceeding with headless SSO callback")
|
||||
HandleHeadlessSSOCallback(w, r)
|
||||
default:
|
||||
logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
|
||||
}
|
||||
} else { // handle normal login
|
||||
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// swagger:route GET /api/oauth/login nodes HandleAuthLogin
|
||||
//
|
||||
// Handles OAuth login.
|
||||
//
|
||||
// Schemes: https
|
||||
//
|
||||
// Security:
|
||||
// oauth
|
||||
// Responses:
|
||||
// 200: okResponse
|
||||
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if auth_provider == nil {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
var functions = getCurrentAuthFunctions()
|
||||
if functions == nil {
|
||||
return
|
||||
}
|
||||
if servercfg.GetFrontendURL() == "" {
|
||||
handleOauthNotConfigured(w)
|
||||
return
|
||||
}
|
||||
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
|
||||
}
|
||||
|
||||
// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
|
||||
func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
logger.Log(0, "failed to establish web-socket connection during headless sign-in")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := &netcache.CValue{User: "", Pass: ""}
|
||||
stateStr := logic.RandomString(headless_signin_length)
|
||||
if err = netcache.Set(stateStr, req); err != nil {
|
||||
logger.Log(0, "Failed to process sso request -", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
timeout := make(chan bool, 1)
|
||||
answer := make(chan string, 1)
|
||||
defer close(answer)
|
||||
defer close(timeout)
|
||||
|
||||
if auth_provider == nil {
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
|
||||
if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {
|
||||
logger.Log(0, "error during message writing:", err.Error())
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
cachedReq, err := netcache.Get(stateStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "expired") {
|
||||
logger.Log(0, "timeout occurred while waiting for SSO")
|
||||
timeout <- true
|
||||
break
|
||||
}
|
||||
continue
|
||||
} else if cachedReq.Pass != "" {
|
||||
logger.Log(0, "SSO process completed for user ", cachedReq.User)
|
||||
answer <- cachedReq.Pass
|
||||
break
|
||||
}
|
||||
time.Sleep(500) // try it 2 times per second to see if auth is completed
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-answer:
|
||||
if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {
|
||||
logger.Log(0, "Error during message writing:", err.Error())
|
||||
}
|
||||
case <-timeout:
|
||||
logger.Log(0, "Authentication server time out for headless SSO login")
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "Error during message writing:", err.Error())
|
||||
}
|
||||
}
|
||||
if err = netcache.Del(stateStr); err != nil {
|
||||
logger.Log(0, "failed to remove SSO cache entry", err.Error())
|
||||
}
|
||||
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
logger.Log(0, "write close:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// == private methods ==
|
||||
|
||||
func getStateAndCode(r *http.Request) (string, string) {
|
||||
var state, code string
|
||||
if r.FormValue("state") != "" && r.FormValue("code") != "" {
|
||||
state = r.FormValue("state")
|
||||
code = r.FormValue("code")
|
||||
} else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" {
|
||||
state = r.URL.Query().Get("state")
|
||||
code = r.URL.Query().Get("code")
|
||||
}
|
||||
|
||||
return state, code
|
||||
}
|
||||
|
||||
func (user *OAuthUser) getUserName() string {
|
||||
var userName string
|
||||
if user.Email != "" {
|
||||
userName = user.Email
|
||||
} else if user.Login != "" {
|
||||
userName = user.Login
|
||||
} else if user.UserPrincipalName != "" {
|
||||
userName = user.UserPrincipalName
|
||||
} else if user.Name != "" {
|
||||
userName = user.Name
|
||||
}
|
||||
return userName
|
||||
}
|
||||
|
||||
func isStateCached(state string) bool {
|
||||
_, err := netcache.Get(state)
|
||||
return err == nil || strings.Contains(err.Error(), "expired")
|
||||
}
|
||||
|
||||
// isEmailAllowed - checks if email is allowed to signup
|
||||
func isEmailAllowed(email string) bool {
|
||||
allowedDomains := servercfg.GetAllowedEmailDomains()
|
||||
domains := strings.Split(allowedDomains, ",")
|
||||
if len(domains) == 1 && domains[0] == "*" {
|
||||
return true
|
||||
}
|
||||
emailParts := strings.Split(email, "@")
|
||||
if len(emailParts) < 2 {
|
||||
return false
|
||||
}
|
||||
baseDomainOfEmail := emailParts[1]
|
||||
for _, domain := range domains {
|
||||
if domain == baseDomainOfEmail {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -101,7 +102,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
|||
handleOauthUserNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
var newPass, fetchErr = auth.FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -101,7 +102,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
|||
handleOauthUserNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
var newPass, fetchErr = auth.FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -104,7 +105,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
handleOauthUserNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
var newPass, fetchErr = auth.FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
|
@ -62,7 +63,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write(response)
|
||||
return
|
||||
}
|
||||
newPass, fetchErr := FetchPassValue("")
|
||||
newPass, fetchErr := auth.FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -114,7 +115,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
|||
handleOauthUserNotAllowed(w)
|
||||
return
|
||||
}
|
||||
var newPass, fetchErr = FetchPassValue("")
|
||||
var newPass, fetchErr = auth.FetchPassValue("")
|
||||
if fetchErr != nil {
|
||||
return
|
||||
}
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -156,23 +155,3 @@ func RegisterHostSSO(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
http.Redirect(w, r, auth_provider.AuthCodeURL(machineKeyStr), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// == private ==
|
||||
|
||||
func isUserIsAllowed(username, network string, shouldAddUser bool) (*models.User, error) {
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
if err != nil && shouldAddUser { // user must not exist, so try to make one
|
||||
if err = addUser(username); err != nil {
|
||||
logger.Log(0, "failed to add user", username, "during a node SSO network join on network", network)
|
||||
// response := returnErrTemplate(user.UserName, "failed to add user", state, reqKeyIf)
|
||||
// w.WriteHeader(http.StatusInternalServerError)
|
||||
// w.Write(response)
|
||||
return nil, fmt.Errorf("failed to add user to system")
|
||||
}
|
||||
logger.Log(0, "user", username, "was added during a node SSO network join on network", network)
|
||||
user, _ = logic.GetUser(username)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/pro/auth"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
|
@ -20,6 +21,10 @@ func UserHandlers(r *mux.Router) {
|
|||
r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/users/{username}/remote_access_gw", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGws)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users/ingress/{ingress_id}", logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/oauth/headless", auth.HandleHeadlessSSO)
|
||||
r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterHostSSO).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// swagger:route POST /api/users/{username}/remote_access_gw user attachUserToRemoteAccessGateway
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/pro/auth"
|
||||
proControllers "github.com/gravitl/netmaker/pro/controllers"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
|
@ -81,6 +82,13 @@ func InitPro() {
|
|||
AddRacHooks()
|
||||
}
|
||||
|
||||
var authProvider = auth.InitializeAuthProvider()
|
||||
if authProvider != "" {
|
||||
slog.Info("OAuth provider,", authProvider+",", "initialized")
|
||||
} else {
|
||||
slog.Error("no OAuth provider found or not configured, continuing without OAuth")
|
||||
}
|
||||
|
||||
})
|
||||
logic.ResetFailOver = proLogic.ResetFailOver
|
||||
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
|
||||
|
|
Loading…
Reference in a new issue