mirror of
https://github.com/gravitl/netmaker.git
synced 2024-09-21 07:46:04 +08:00
Merge pull request #1444 from gravitl/feature_v0.14.7_ha_sso
added better state management to make OAuth sign-ins HA
This commit is contained in:
commit
61553d70ab
|
@ -29,7 +29,6 @@ const (
|
||||||
auth_key = "netmaker_auth"
|
auth_key = "netmaker_auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
|
|
||||||
var auth_provider *oauth2.Config
|
var auth_provider *oauth2.Config
|
||||||
|
|
||||||
func getCurrentAuthFunctions() map[string]interface{} {
|
func getCurrentAuthFunctions() map[string]interface{} {
|
||||||
|
|
|
@ -41,7 +41,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
|
func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
oauth_state_string = logic.RandomString(16)
|
var oauth_state_string = logic.RandomString(16)
|
||||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
|
@ -49,6 +49,12 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logic.SetState(oauth_state_string); err != nil {
|
||||||
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
@ -88,7 +94,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
|
func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
|
||||||
if state != oauth_state_string {
|
oauth_state_string, isValid := logic.IsStateValid(state)
|
||||||
|
if !isValid || state != oauth_state_string {
|
||||||
return nil, fmt.Errorf("invalid oauth state")
|
return nil, fmt.Errorf("invalid oauth state")
|
||||||
}
|
}
|
||||||
var token, err = auth_provider.Exchange(context.Background(), code)
|
var token, err = auth_provider.Exchange(context.Background(), code)
|
||||||
|
|
|
@ -41,7 +41,7 @@ func initGithub(redirectURL string, clientID string, clientSecret string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
|
func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
oauth_state_string = logic.RandomString(16)
|
var oauth_state_string = logic.RandomString(16)
|
||||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
|
@ -49,6 +49,12 @@ func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logic.SetState(oauth_state_string); err != nil {
|
||||||
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
@ -88,7 +94,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
|
func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
|
||||||
if state != oauth_state_string {
|
oauth_state_string, isValid := logic.IsStateValid(state)
|
||||||
|
if !isValid || state != oauth_state_string {
|
||||||
return nil, fmt.Errorf("invalid OAuth state")
|
return nil, fmt.Errorf("invalid OAuth state")
|
||||||
}
|
}
|
||||||
var token, err = auth_provider.Exchange(context.Background(), code)
|
var token, err = auth_provider.Exchange(context.Background(), code)
|
||||||
|
|
|
@ -42,7 +42,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
oauth_state_string = logic.RandomString(16)
|
var oauth_state_string = logic.RandomString(16)
|
||||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
|
@ -50,6 +50,12 @@ func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logic.SetState(oauth_state_string); err != nil {
|
||||||
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
@ -89,7 +95,8 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
|
func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
|
||||||
if state != oauth_state_string {
|
oauth_state_string, isValid := logic.IsStateValid(state)
|
||||||
|
if !isValid || state != oauth_state_string {
|
||||||
return nil, fmt.Errorf("invalid OAuth state")
|
return nil, fmt.Errorf("invalid OAuth state")
|
||||||
}
|
}
|
||||||
var token, err = auth_provider.Exchange(context.Background(), code)
|
var token, err = auth_provider.Exchange(context.Background(), code)
|
||||||
|
|
11
auth/oidc.go
11
auth/oidc.go
|
@ -54,7 +54,7 @@ func initOIDC(redirectURL string, clientID string, clientSecret string, issuer s
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
|
func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
oauth_state_string = logic.RandomString(16)
|
var oauth_state_string = logic.RandomString(16)
|
||||||
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
if auth_provider == nil && servercfg.GetFrontendURL() != "" {
|
||||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
return
|
return
|
||||||
|
@ -62,6 +62,12 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := logic.SetState(oauth_state_string); err != nil {
|
||||||
|
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
var url = auth_provider.AuthCodeURL(oauth_state_string)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
@ -101,7 +107,8 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) {
|
func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) {
|
||||||
if state != oauth_state_string {
|
oauth_state_string, isValid := logic.IsStateValid(state)
|
||||||
|
if !isValid || state != oauth_state_string {
|
||||||
return nil, fmt.Errorf("invalid OAuth state")
|
return nil, fmt.Errorf("invalid OAuth state")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,9 @@ const GENERATED_TABLE_NAME = "generated"
|
||||||
// NODE_ACLS_TABLE_NAME - stores the node ACL rules
|
// NODE_ACLS_TABLE_NAME - stores the node ACL rules
|
||||||
const NODE_ACLS_TABLE_NAME = "nodeacls"
|
const NODE_ACLS_TABLE_NAME = "nodeacls"
|
||||||
|
|
||||||
|
// SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins
|
||||||
|
const SSO_STATE_CACHE = "ssostatecache"
|
||||||
|
|
||||||
// == ERROR CONSTS ==
|
// == ERROR CONSTS ==
|
||||||
|
|
||||||
// NO_RECORD - no singular result found
|
// NO_RECORD - no singular result found
|
||||||
|
@ -135,6 +138,7 @@ func createTables() {
|
||||||
createTable(SERVER_UUID_TABLE_NAME)
|
createTable(SERVER_UUID_TABLE_NAME)
|
||||||
createTable(GENERATED_TABLE_NAME)
|
createTable(GENERATED_TABLE_NAME)
|
||||||
createTable(NODE_ACLS_TABLE_NAME)
|
createTable(NODE_ACLS_TABLE_NAME)
|
||||||
|
createTable(SSO_STATE_CACHE)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTable(tableName string) error {
|
func createTable(tableName string) error {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/gravitl/netmaker/database"
|
"github.com/gravitl/netmaker/database"
|
||||||
|
@ -270,3 +271,52 @@ func FetchAuthSecret(key string, secret string) (string, error) {
|
||||||
}
|
}
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetState - gets an SsoState from DB, if expired returns error
|
||||||
|
func GetState(state string) (*models.SsoState, error) {
|
||||||
|
var s models.SsoState
|
||||||
|
record, err := database.FetchRecord(database.SSO_STATE_CACHE, state)
|
||||||
|
if err != nil {
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal([]byte(record), &s); err != nil {
|
||||||
|
return &s, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.IsExpired() {
|
||||||
|
return &s, fmt.Errorf("state expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetState - sets a state with new expiration
|
||||||
|
func SetState(state string) error {
|
||||||
|
s := models.SsoState{
|
||||||
|
Value: state,
|
||||||
|
Expiration: time.Now().Add(models.DefaultExpDuration),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(&s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return database.Insert(state, string(data), database.SSO_STATE_CACHE)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsStateValid - checks if given state is valid or not
|
||||||
|
// deletes state after call is made to clean up, should only be called once per sign-in
|
||||||
|
func IsStateValid(state string) (string, bool) {
|
||||||
|
s, err := GetState(state)
|
||||||
|
if s.Value != "" {
|
||||||
|
delState(state)
|
||||||
|
}
|
||||||
|
return s.Value, err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delState - removes a state from cache/db
|
||||||
|
func delState(state string) error {
|
||||||
|
return database.DeleteRecord(database.SSO_STATE_CACHE, state)
|
||||||
|
}
|
||||||
|
|
17
models/ssocache.go
Normal file
17
models/ssocache.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// DefaultExpDuration - the default expiration time of SsoState
|
||||||
|
const DefaultExpDuration = time.Minute * 5
|
||||||
|
|
||||||
|
// SsoState - holds SSO sign-in session data
|
||||||
|
type SsoState struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
Expiration time.Time `json:"expiration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoState.IsExpired - tells if an SsoState is expired or not
|
||||||
|
func (s *SsoState) IsExpired() bool {
|
||||||
|
return time.Now().After(s.Expiration)
|
||||||
|
}
|
Loading…
Reference in a new issue