netmaker/pro/auth/auth.go
Vishal Dalwadi 614cf77b5a
NET-1991: Add IDP sync functionality. (#3428)
* feat: api access tokens

* revoke all user tokens

* redefine access token api routes, add auto egress option to enrollment keys

* add server settings APIs, add db table for settings

* handle server settings updates

* switch to using settings from DB

* fix server settings migration

* revert force migration for settings

* fix server settings database write

* fix revoked tokens to be unauthorized

* remove unused functions

* convert access token to sql schema

* switch access token to sql schema

* fix merge conflicts

* fix server settings types

* bypass basic auth setting for super admin

* add TODO comment

* feat(go): add types for idp package;

* feat(go): import azure sdk;

* feat(go): add stub for google workspace client;

* feat(go): implement azure ad client;

* feat(go): sync users and groups using idp client;

* publish peer update on settings update

* feat(go): read creds from env vars;

* feat(go): add api endpoint to trigger idp sync;

* fix(go): sync member changes;

* fix(go): handle error;

* fix(go): set correct response type;

* feat(go): support disabling user accounts;

1. Add api endpoints to enable and disable user accounts.
2. Add checks in authenticators to prevent disabled users from logging in.
3. Add checks in middleware to prevent api usage by disabled users.

* feat(go): use string slice for group members;

* feat(go): sync user account status from idp;

* feat(go): import google admin sdk;

* feat(go): add support for google workspace idp;

* feat(go): initialize idp client on sync;

* feat(go): sync from idp periodically;

* feat(go): improvements for google idp;

1. Use the impersonate package to authenticate.
2. Use Pages method to get all data.

* chore(go): import style changes from migration branch;

1. Singular file names for table schema.
2. No table name method.
3. Use .Model instead of .Table.
4. No unnecessary tagging.

* remove nat check on egress gateway request

* Revert "remove nat check on egress gateway request"

This reverts commit 0aff12a189.

* feat(go): add db middleware;

* feat(go): restore method;

* feat(go): add user access token schema;

* fix user auth api:

* re-initialise oauth and email config

* feat(go): fetch idp creds from server settings;

* feat(go): add filters for users and groups;

* feat(go): skip sync from idp if disabled;

* feat(go): add endpoint to remove idp integration;

* feat(go): import all users if no filters;

* feat(go): assign service-user role on sync;

* feat(go): remove microsoft-go-sdk;

* feat(go): add display name field for user;

* fix(go): set account disabled correctly;

* fix(go): update user if display name changes;

* fix(go): remove auth provider when removing idp integration;

* fix(go): ignore display name if empty;

* feat(go): add idp sync interval setting;

* fix(go): error on invalid auth provider;

* fix(go): no error if no user on group delete;

* fix(go): check superadmin using platform role id;

* feat(go): add display name and account disabled to return user as well;

* feat(go): tidy go mod after merge;

* feat(go): reinitialize auth provider and idp sync hook;

* fix(go): merge error;

* fix(go): merge error;

* feat(go): use id as the external provider id;

* fix(go): comments;

* feat(go): add function to return pending users;

* feat(go): prevent external id erasure;

* fix(go): user and group sync errors;

* chore(go): cleanup;

* fix(go): delete only oauth users;

* feat(go): use uuid group id;

* export IdP ID in REST API

* feat(go): don't use uuid for default groups;

* feat(go): migrate group only if id not uuid;

* chore(go): go mod tidy;

---------

Co-authored-by: abhishek9686 <abhi281342@gmail.com>
Co-authored-by: Abhishek K <abhishek@netmaker.io>
Co-authored-by: the_aceix <aceixsmartx@gmail.com>
2025-05-21 13:48:15 +05:30

306 lines
8.7 KiB
Go

package auth
import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/websocket"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
)
// == consts ==
const (
// init_provider, handle_callback and handle_login are used in this file as
// keys into the provider function table returned by getCurrentAuthFunctions;
// get_user_info is presumably a key of the same table (not referenced here).
init_provider = "initprovider"
get_user_info = "getuserinfo"
handle_callback = "handlecallback"
handle_login = "handlelogin"
// Provider names, matched against the first entry returned by
// logic.GetAuthProviderInfo.
google_provider_name = "google"
azure_ad_provider_name = "azure-ad"
github_provider_name = "github"
oidc_provider_name = "oidc"
// NOTE(review): verify_user is not referenced in this file; presumably used
// by provider implementations elsewhere in the package — confirm.
verify_user = "verifyuser"
// Lengths of random OAuth state strings. HandleAuthCallback routes a cached
// state by its length: node_signin_length -> host SSO callback,
// headless_signin_length -> headless (websocket/CLI) SSO callback.
// user_signin_length is not referenced in this file.
user_signin_length = 16
node_signin_length = 64
headless_signin_length = 32
)
// OAuthUser - generic OAuth strategy user
//
// Provider-neutral view of a user profile returned by an OAuth/OIDC provider.
// Not every provider fills every field; getUserName uses the first non-empty
// of Email, Login, UserPrincipalName and Name.
type OAuthUser struct {
// ID is the user identifier reported by the provider.
ID string `json:"id" bson:"id"`
Name string `json:"name" bson:"name"`
Email string `json:"email" bson:"email"`
// Login is presumably a provider login handle (e.g. GitHub-style) — confirm.
Login string `json:"login" bson:"login"`
// UserPrincipalName presumably carries an Azure AD style UPN — confirm.
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
AccessToken string `json:"accesstoken" bson:"accesstoken"`
}
var (
// auth_provider is the OAuth2 configuration of the active provider. The HTTP
// handlers in this file treat nil as "OAuth not configured"; ResetAuthProvider
// clears it, and presumably the provider init functions set it — confirm.
auth_provider *oauth2.Config
// upgrader turns headless SSO requests into websocket connections; it is the
// zero-value Upgrader (library defaults, no custom CheckOrigin set here).
upgrader = websocket.Upgrader{}
)
// getCurrentAuthFunctions returns the function table of the auth provider
// currently named in the server settings (looked up in this file by keys such
// as init_provider, handle_login and handle_callback), or nil when the
// configured provider is not one this package implements.
func getCurrentAuthFunctions() map[string]interface{} {
	providerName := logic.GetAuthProviderInfo(logic.GetServerSettings())[0]
	tables := map[string]map[string]interface{}{
		google_provider_name:   google_functions,
		azure_ad_provider_name: azure_ad_functions,
		github_provider_name:   github_functions,
		oidc_provider_name:     oidc_functions,
	}
	// A missing key yields the zero value, i.e. a nil table.
	return tables[providerName]
}
// ResetAuthProvider resets the auth provider configuration.
//
// It re-reads the server settings and re-runs InitializeAuthProvider. The
// cached provider configuration is cleared when the settings name no auth
// provider, and also when InitializeAuthProvider reports that no supported
// provider was initialized, so a configuration left over from a previous
// provider does not stay active.
func ResetAuthProvider() {
	settings := logic.GetServerSettings()
	if settings.AuthProvider == "" {
		auth_provider = nil
	}
	if InitializeAuthProvider() == "" {
		// No provider init function ran for the current settings, so whatever
		// auth_provider still holds belongs to an earlier configuration.
		auth_provider = nil
	}
}
// InitializeAuthProvider - initializes the auth provider if any is present
//
// It resolves the function table of the configured provider, rotates the OAuth
// auth secret, builds the callback URL <scheme>://<api host>/api/oauth/callback
// (plain http only when the API host contains "localhost" or "127.0.0.1",
// https otherwise) and calls the provider's init function with authInfo[1] and
// authInfo[2] (presumably client ID and secret — see logic.GetAuthProviderInfo),
// plus authInfo[3] for OIDC. It returns the provider name on success and ""
// when no supported provider is configured or it could not be initialized.
func InitializeAuthProvider() string {
	var functions = getCurrentAuthFunctions()
	if functions == nil {
		return ""
	}
	logger.Log(0, "setting oauth secret")
	var err = logic.SetAuthSecret(logic.RandomString(64))
	if err != nil {
		logger.FatalLog("failed to set auth_secret", err.Error())
	}
	var authInfo = logic.GetAuthProviderInfo(logic.GetServerSettings())
	// authInfo is indexed up to [2] below (and [3] for OIDC); refuse to
	// initialize instead of panicking on incomplete settings.
	if len(authInfo) < 3 {
		logger.Log(0, "incomplete auth provider settings, skipping OAuth initialization")
		return ""
	}
	var serverConn = servercfg.GetAPIHost()
	if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") {
		serverConn = "http://" + serverConn
		logger.Log(1, "localhost OAuth detected, proceeding with insecure http redirect: (", serverConn, ")")
	} else {
		serverConn = "https://" + serverConn
		logger.Log(1, "external OAuth detected, proceeding with https redirect: ("+serverConn+")")
	}
	var callbackURL = serverConn + "/api/oauth/callback"
	if authInfo[0] == oidc_provider_name {
		// OIDC additionally needs the issuer (authInfo[3]).
		initOIDC, ok := functions[init_provider].(func(string, string, string, string))
		if !ok || len(authInfo) < 4 {
			logger.Log(0, "unable to initialize oidc auth provider: missing issuer or init function")
			return ""
		}
		initOIDC(callbackURL, authInfo[1], authInfo[2], authInfo[3])
		return authInfo[0]
	}
	// Checked assertion: a table/settings mismatch must not crash the server.
	initProvider, ok := functions[init_provider].(func(string, string, string))
	if !ok {
		logger.Log(0, "unable to initialize auth provider:", authInfo[0])
		return ""
	}
	initProvider(callbackURL, authInfo[1], authInfo[2])
	return authInfo[0]
}
// HandleAuthCallback - handles oauth callback
// Note: not included in API reference as part of the OAuth process itself.
//
// When the callback's state is present in netcache (including an entry that
// has just expired), the callback belongs to a host (node) or headless SSO
// flow and is routed by the state's length. Any other state is handled by the
// active provider's regular login callback.
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
	if auth_provider == nil {
		handleOauthNotConfigured(w)
		return
	}
	var functions = getCurrentAuthFunctions()
	if functions == nil {
		// A provider config is cached but the settings no longer name a
		// supported provider; report it instead of sending an empty 200.
		handleOauthNotConfigured(w)
		return
	}
	state, _ := getStateAndCode(r)
	// If the state is in netcache, proceed with node/headless registration login.
	_, err := netcache.Get(state)
	if err != nil && !errors.Is(err, netcache.ErrExpired) {
		// handle normal login
		functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
		return
	}
	switch len(state) {
	case node_signin_length:
		logger.Log(1, "proceeding with host SSO callback")
		HandleHostSSOCallback(w, r)
	case headless_signin_length:
		logger.Log(1, "proceeding with headless SSO callback")
		HandleHeadlessSSOCallback(w, r)
	default:
		logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
		// Tell the client the request was rejected rather than leaving the
		// response empty.
		http.Error(w, "invalid state", http.StatusBadRequest)
	}
}
// swagger:route GET /api/oauth/login nodes HandleAuthLogin
//
// Handles OAuth login.
//
// Schemes: https
//
// Security:
// oauth
// Responses:
// 200: okResponse
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
	// Every "cannot start a login" case below answers through
	// handleOauthNotConfigured so the client never gets an empty response.
	if auth_provider == nil {
		handleOauthNotConfigured(w)
		return
	}
	var functions = getCurrentAuthFunctions()
	if functions == nil {
		// A provider config is cached but the settings no longer name a
		// supported provider.
		handleOauthNotConfigured(w)
		return
	}
	if servercfg.GetFrontendURL() == "" {
		handleOauthNotConfigured(w)
		return
	}
	// Delegate to the active provider's login handler.
	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
}
// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
//
// Flow:
//  1. Upgrade the request to a websocket.
//  2. Register a random headless_signin_length state in netcache.
//  3. Send the client the registration URL that contains the state.
//  4. Poll netcache twice per second until the entry's Pass is filled in and
//     sent back to the client, or until the entry expires.
//
// Pass is presumably written by HandleHeadlessSSOCallback — confirm. The cache
// entry is removed on every path after it was created, and a normal close
// frame is sent once before returning.
func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())
		return
	}
	if conn == nil {
		logger.Log(0, "failed to establish web-socket connection during headless sign-in")
		return
	}
	defer conn.Close()

	// Checked before registering a state so no netcache entry is left behind.
	if auth_provider == nil {
		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
			logger.Log(0, "error during message writing:", err.Error())
		}
		return
	}

	req := &netcache.CValue{User: "", Pass: ""}
	stateStr := logic.RandomString(headless_signin_length)
	if err = netcache.Set(stateStr, req); err != nil {
		logger.Log(0, "Failed to process sso request -", err.Error())
		return
	}
	defer func() {
		if delErr := netcache.Del(stateStr); delErr != nil {
			logger.Log(0, "failed to remove SSO cache entry", delErr.Error())
		}
	}()

	// Local on purpose: a shared package-level variable would be raced on by
	// concurrent headless logins.
	redirectURL := fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
	if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectURL)); err != nil {
		logger.Log(0, "error during message writing:", err.Error())
		// The client never received the URL, so it cannot finish the login.
		return
	}

	for {
		cached, getErr := netcache.Get(stateStr)
		if getErr != nil {
			if strings.Contains(getErr.Error(), "expired") {
				logger.Log(0, "timeout occurred while waiting for SSO")
				logger.Log(0, "Authentication server time out for headless SSO login")
				break
			}
			// NOTE(review): other lookup errors (e.g. a removed entry) are
			// retried until expiry; consider adding a hard deadline.
		} else if cached != nil && cached.Pass != "" {
			logger.Log(0, "SSO process completed for user ", cached.User)
			if err = conn.WriteMessage(websocket.TextMessage, []byte(cached.Pass)); err != nil {
				logger.Log(0, "Error during message writing:", err.Error())
			}
			break
		}
		// Poll twice per second; applies to the error path too so a failing
		// lookup does not turn into a busy loop.
		time.Sleep(500 * time.Millisecond)
	}

	// Single close frame for both the success and the timeout path.
	if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
		logger.Log(0, "write close:", err.Error())
	}
}
// == private methods ==
// getStateAndCode extracts the OAuth "state" and "code" parameters from a
// callback request. The pair is taken from the parsed form values first and
// from the URL query second. A source is used only when both values are
// non-empty there; otherwise ("", "") is returned.
func getStateAndCode(r *http.Request) (string, string) {
	if formState, formCode := r.FormValue("state"), r.FormValue("code"); formState != "" && formCode != "" {
		return formState, formCode
	}
	query := r.URL.Query()
	if queryState, queryCode := query.Get("state"), query.Get("code"); queryState != "" && queryCode != "" {
		return queryState, queryCode
	}
	return "", ""
}
// getUserEmailFromClaims returns the "email" claim of a JWT. It returns ""
// when the token cannot be decoded, carries no map claims, or its email claim
// is missing or not a string.
//
// The signature is NOT verified; this only reads a claim. NOTE(review):
// presumably the token comes straight from the identity provider's token
// endpoint — do not use this on client-supplied tokens.
func getUserEmailFromClaims(token string) string {
	// The parse error is deliberately ignored: a token that fails to decode
	// yields no usable claims and is rejected by the checks below.
	accessToken, _, _ := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
	if accessToken == nil {
		return ""
	}
	claims, ok := accessToken.Claims.(jwt.MapClaims)
	if !ok || claims == nil {
		return ""
	}
	// Checked assertion: a non-string email claim must not panic the handler.
	email, ok := claims["email"].(string)
	if !ok {
		return ""
	}
	return email
}
// getUserName returns the first non-empty identifier of the OAuth user,
// checked in the order Email, Login, UserPrincipalName, Name. It returns ""
// when all four are empty.
func (user *OAuthUser) getUserName() string {
	switch {
	case user.Email != "":
		return user.Email
	case user.Login != "":
		return user.Login
	case user.UserPrincipalName != "":
		return user.UserPrincipalName
	default:
		// Name is the last fallback; it may itself be empty.
		return user.Name
	}
}
// isStateCached reports whether an OAuth state value is known to netcache.
// A successful lookup counts as cached. So does a lookup whose error message
// mentions "expired"; any other lookup error means the state is not cached.
func isStateCached(state string) bool {
	if _, lookupErr := netcache.Get(state); lookupErr != nil {
		return strings.Contains(lookupErr.Error(), "expired")
	}
	return true
}
// isEmailAllowed - checks if email is allowed to signup
//
// The allowed domains come from logic.GetAllowedEmailDomains() as a
// comma-separated list, and a list consisting only of "*" allows every
// address. Each entry is trimmed of surrounding whitespace and compared
// case-insensitively (domain names are case-insensitive) with the part of the
// email after its LAST "@". Using the last "@" prevents an address such as
// "x@allowed.com@other.com" from passing as allowed.com. Addresses without an
// "@" or with an empty domain are rejected.
func isEmailAllowed(email string) bool {
	allowedDomains := logic.GetAllowedEmailDomains()
	domains := strings.Split(allowedDomains, ",")
	if len(domains) == 1 && strings.TrimSpace(domains[0]) == "*" {
		return true
	}
	at := strings.LastIndex(email, "@")
	if at < 0 {
		return false
	}
	emailDomain := email[at+1:]
	if emailDomain == "" {
		// Otherwise an empty/blank allow-list entry would match "user@".
		return false
	}
	for _, domain := range domains {
		if strings.EqualFold(strings.TrimSpace(domain), emailDomain) {
			return true
		}
	}
	return false
}