diff --git a/auth/auth.go b/auth/auth.go index 2725c679..52065037 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -29,7 +29,6 @@ const ( auth_key = "netmaker_auth" ) -var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login var auth_provider *oauth2.Config func getCurrentAuthFunctions() map[string]interface{} { diff --git a/auth/azure-ad.go b/auth/azure-ad.go index f828d3bc..b2931b50 100644 --- a/auth/azure-ad.go +++ b/auth/azure-ad.go @@ -41,7 +41,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) { } func handleAzureLogin(w http.ResponseWriter, r *http.Request) { - oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(16) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -49,6 +49,12 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) return } + + if err := logic.SetState(oauth_state_string); err != nil { + http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -88,7 +94,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { } func getAzureUserInfo(state string, code string) (*azureOauthUser, error) { - if state != oauth_state_string { + oauth_state_string, isValid := logic.IsStateValid(state) + if !isValid || state != oauth_state_string { return nil, fmt.Errorf("invalid oauth state") } var token, err = auth_provider.Exchange(context.Background(), code) diff --git a/auth/github.go b/auth/github.go index dfb98241..2bbdfdea 100644 --- a/auth/github.go +++ b/auth/github.go @@ -41,7 +41,7 @@ func initGithub(redirectURL string, clientID string, clientSecret string) { } func handleGithubLogin(w http.ResponseWriter, r *http.Request) { - oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(16) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -49,6 +49,12 @@ func handleGithubLogin(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) return } + + if err := logic.SetState(oauth_state_string); err != nil { + http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -88,7 +94,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { } func getGithubUserInfo(state string, code string) (*githubOauthUser, error) { - if state != oauth_state_string { + oauth_state_string, isValid := logic.IsStateValid(state) + if !isValid || state != oauth_state_string { return nil, fmt.Errorf("invalid OAuth state") } var token, err = auth_provider.Exchange(context.Background(), code) diff --git a/auth/google.go b/auth/google.go index 4397ccaf..344c9938 100644 --- a/auth/google.go +++ b/auth/google.go @@ -42,7 +42,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) { } func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { - oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(16) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -50,6 +50,12 @@ func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) return } + + if err := logic.SetState(oauth_state_string); err != nil { + http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -89,7 +95,8 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { } func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) { - if state != oauth_state_string { + oauth_state_string, isValid := logic.IsStateValid(state) + if !isValid || state != oauth_state_string { return nil, fmt.Errorf("invalid OAuth state") } var token, err = auth_provider.Exchange(context.Background(), code) diff --git a/auth/oidc.go b/auth/oidc.go index ce87a574..77e26ad9 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -54,7 +54,7 @@ func initOIDC(redirectURL string, clientID string, clientSecret string, issuer s } func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { - oauth_state_string = logic.RandomString(16) + var oauth_state_string = logic.RandomString(16) if auth_provider == nil && servercfg.GetFrontendURL() != "" { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return @@ -62,6 +62,12 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials")) return } + + if err := logic.SetState(oauth_state_string); err != nil { + http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) + return + } + var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } @@ -101,7 +107,8 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { } func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) { - if state != oauth_state_string { + oauth_state_string, isValid := logic.IsStateValid(state) + if !isValid || state != oauth_state_string { return nil, fmt.Errorf("invalid OAuth state") } diff --git a/database/database.go b/database/database.go index c01705b7..2b3b0d4a 100644 --- a/database/database.go +++ b/database/database.go @@ -56,6 +56,9 @@ const GENERATED_TABLE_NAME = "generated" // NODE_ACLS_TABLE_NAME - stores the node ACL rules const NODE_ACLS_TABLE_NAME = "nodeacls" +// SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins +const SSO_STATE_CACHE = "ssostatecache" + // == ERROR CONSTS == // NO_RECORD - no singular result found @@ -135,6 +138,7 @@ func createTables() { createTable(SERVER_UUID_TABLE_NAME) createTable(GENERATED_TABLE_NAME) createTable(NODE_ACLS_TABLE_NAME) + createTable(SSO_STATE_CACHE) } func createTable(tableName string) error { diff --git a/logic/auth.go b/logic/auth.go index 75fcd975..ba8205bc 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/go-playground/validator/v10" "github.com/gravitl/netmaker/database" @@ -270,3 +271,52 @@ func FetchAuthSecret(key string, secret string) (string, error) { } return record, nil } + +// GetState - gets an SsoState from DB, if expired returns error +func GetState(state string) (*models.SsoState, error) { + var s models.SsoState + record, err := database.FetchRecord(database.SSO_STATE_CACHE, state) + if err != nil { + return &s, err + } + + if err = json.Unmarshal([]byte(record), &s); err != nil { + return &s, err + } + + if s.IsExpired() { + return &s, fmt.Errorf("state expired") + } + + return &s, nil +} + +// SetState - sets a state with new expiration +func SetState(state string) error { + s := models.SsoState{ + Value: state, + Expiration: time.Now().Add(models.DefaultExpDuration), + } + + data, err := json.Marshal(&s) + if err != nil { + return err + } + + return database.Insert(state, string(data), database.SSO_STATE_CACHE) +} + +// IsStateValid - checks if given state is valid or not +// deletes state after call is made to clean up, should only be called once per sign-in +func IsStateValid(state string) (string, bool) { + s, err := GetState(state) + if s.Value != "" { + delState(state) + } + return s.Value, err == nil +} + +// delState - removes a state from cache/db +func delState(state string) error { + return database.DeleteRecord(database.SSO_STATE_CACHE, state) +} diff --git a/models/ssocache.go b/models/ssocache.go new file mode 100644 index 00000000..90e61285 --- /dev/null +++ b/models/ssocache.go @@ -0,0 +1,17 @@ +package models + +import "time" + +// DefaultExpDuration - the default expiration time of SsoState +const DefaultExpDuration = time.Minute * 5 + +// SsoState - holds SSO sign-in session data +type SsoState struct { + Value string `json:"value"` + Expiration time.Time `json:"expiration"` +} + +// SsoState.IsExpired - tells if an SsoState is expired or not +func (s *SsoState) IsExpired() bool { + return time.Now().After(s.Expiration) +}