diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..92728804 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,66 @@ +package auth + +import ( + "net/http" + + "github.com/gravitl/netmaker/servercfg" + "golang.org/x/oauth2" +) + +// == consts == +const ( + init_provider = "initprovider" + get_user_info = "getuserinfo" + handle_callback = "handlecallback" + handle_login = "handlelogin" + oauth_state_string = "netmaker-oauth-state" + google_provider_name = "google" + azure_ad_provider_name = "azure-ad" + github_provider_name = "github" +) + +var auth_provider *oauth2.Config + +func getCurrentAuthFunctions() map[string]interface{} { + var authInfo = servercfg.GetAuthProviderInfo() + var authProvider = authInfo[0] + switch authProvider { + case google_provider_name: + return google_functions + case azure_ad_provider_name: + return google_functions + case github_provider_name: + return google_functions + default: + return nil + } +} + +// InitializeAuthProvider - initializes the auth provider if any is present +func InitializeAuthProvider() bool { + var functions = getCurrentAuthFunctions() + if functions == nil { + return false + } + var authInfo = servercfg.GetAuthProviderInfo() + functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString(), authInfo[1], authInfo[2]) + return auth_provider != nil +} + +// HandleAuthCallback - handles oauth callback +func HandleAuthCallback(w http.ResponseWriter, r *http.Request) { + var functions = getCurrentAuthFunctions() + if functions == nil { + return + } + functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r) +} + +// HandleAuthLogin - handles oauth login +func HandleAuthLogin(w http.ResponseWriter, r *http.Request) { + var functions = getCurrentAuthFunctions() + if functions == nil { + return + } + functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r) +} diff --git a/auth/google.go b/auth/google.go new file mode 100644 index 00000000..6bafb3f3 --- /dev/null +++ b/auth/google.go @@ -0,0 +1,65 @@ +package auth + +import ( + "fmt" + "io/ioutil" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +var google_functions = map[string]interface{}{ + init_provider: initGoogle, + get_user_info: getUserInfo, + handle_callback: handleGoogleCallback, + handle_login: handleGoogleLogin, +} + +// == handle google authentication here == + +func initGoogle(redirectURL string, clientID string, clientSecret string) { + auth_provider = &oauth2.Config{ + RedirectURL: redirectURL, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, + Endpoint: google.Endpoint, + } +} + +func handleGoogleLogin(w http.ResponseWriter, r *http.Request) { + url := auth_provider.AuthCodeURL(oauth_state_string) + http.Redirect(w, r, url, http.StatusTemporaryRedirect) +} + +func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { + + var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code")) + if err != nil { + fmt.Println(err.Error()) + http.Redirect(w, r, "/api/oauth/error", http.StatusTemporaryRedirect) + return + } + fmt.Fprintf(w, "Content: %s\n", content) +} + +func getUserInfo(state string, code string) ([]byte, error) { + if state != oauth_state_string { + return nil, fmt.Errorf("invalid oauth state") + } + token, err := auth_provider.Exchange(oauth2.NoContext, code) + if err != nil { + return nil, fmt.Errorf("code exchange failed: %s", err.Error()) + } + response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed getting user info: %s", err.Error()) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("failed reading response body: %s", err.Error()) + } + return contents, nil +} diff --git a/config/config.go b/config/config.go index 367bbb83..12c5d91a 100644 --- a/config/config.go +++ b/config/config.go @@ -30,49 +30,51 @@ var Config *EnvironmentConfig // EnvironmentConfig : type EnvironmentConfig struct { Server ServerConfig `yaml:"server"` - SQL SQLConfig `yaml:"sql"` + SQL SQLConfig `yaml:"sql"` } // ServerConfig : type ServerConfig struct { - CoreDNSAddr string `yaml:"corednsaddr"` - APIConnString string `yaml:"apiconn"` - APIHost string `yaml:"apihost"` - APIPort string `yaml:"apiport"` - GRPCConnString string `yaml:"grpcconn"` - GRPCHost string `yaml:"grpchost"` - GRPCPort string `yaml:"grpcport"` - GRPCSecure string `yaml:"grpcsecure"` - MasterKey string `yaml:"masterkey"` - AllowedOrigin string `yaml:"allowedorigin"` - NodeID string `yaml:"nodeid"` - RestBackend string `yaml:"restbackend"` - AgentBackend string `yaml:"agentbackend"` - ClientMode string `yaml:"clientmode"` - DNSMode string `yaml:"dnsmode"` - SplitDNS string `yaml:"splitdns"` - DisableRemoteIPCheck string `yaml:"disableremoteipcheck"` - DisableDefaultNet string `yaml:"disabledefaultnet"` - GRPCSSL string `yaml:"grpcssl"` - Version string `yaml:"version"` - SQLConn string `yaml:"sqlconn"` - Platform string `yaml:"platform"` - Database string `yaml:database` - CheckinInterval string `yaml:checkininterval` - DefaultNodeLimit int32 `yaml:"defaultnodelimit"` - Verbosity int32 `yaml:"verbosity"` + CoreDNSAddr string `yaml:"corednsaddr"` + APIConnString string `yaml:"apiconn"` + APIHost string `yaml:"apihost"` + APIPort string `yaml:"apiport"` + GRPCConnString string `yaml:"grpcconn"` + GRPCHost string `yaml:"grpchost"` + GRPCPort string `yaml:"grpcport"` + GRPCSecure string `yaml:"grpcsecure"` + MasterKey string `yaml:"masterkey"` + AllowedOrigin string `yaml:"allowedorigin"` + NodeID string `yaml:"nodeid"` + RestBackend string `yaml:"restbackend"` + AgentBackend string `yaml:"agentbackend"` + ClientMode string `yaml:"clientmode"` + DNSMode string `yaml:"dnsmode"` + SplitDNS string `yaml:"splitdns"` + DisableRemoteIPCheck string `yaml:"disableremoteipcheck"` + DisableDefaultNet string `yaml:"disabledefaultnet"` + GRPCSSL string `yaml:"grpcssl"` + Version string `yaml:"version"` + SQLConn string `yaml:"sqlconn"` + Platform string `yaml:"platform"` + Database string `yaml:database` + CheckinInterval string `yaml:checkininterval` + DefaultNodeLimit int32 `yaml:"defaultnodelimit"` + Verbosity int32 `yaml:"verbosity"` ServerCheckinInterval int64 `yaml:"servercheckininterval"` + AuthProvider string `yaml:"authprovider"` + ClientID string `yaml:"clientid"` + ClientSecret string `yaml:"clientsecret"` } - // Generic SQL Config type SQLConfig struct { - Host string `yaml:"host"` - Port int32 `yaml:"port"` + Host string `yaml:"host"` + Port int32 `yaml:"port"` Username string `yaml:"username"` Password string `yaml:"password"` - DB string `yaml:"db"` - SSLMode string `yaml:"sslmode"` + DB string `yaml:"db"` + SSLMode string `yaml:"sslmode"` } //reading in the env file diff --git a/go.mod b/go.mod index 00453777..59748f11 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/urfave/cli/v2 v2.3.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 // indirect + golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e // indirect golang.org/x/text v0.3.7-0.20210524175448-3115f89c4b99 // indirect golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 // indirect diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 29676f23..84327087 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strconv" + "strings" "github.com/gravitl/netmaker/config" ) @@ -65,6 +66,12 @@ func GetServerConfig() config.ServerConfig { cfg.Database = GetDB() cfg.Platform = GetPlatform() cfg.Version = GetVersion() + + // == auth config == + var authInfo = GetAuthProviderInfo() + cfg.AuthProvider = authInfo[0] + cfg.ClientID = authInfo[1] + cfg.ClientSecret = authInfo[2] return cfg } func GetAPIConnString() string { @@ -398,6 +405,25 @@ func GetServerCheckinInterval() int64 { return t } +// GetAuthProviderInfo = gets the oauth provider info +func GetAuthProviderInfo() []string { + var authProvider = "" + if os.Getenv("AUTH_PROVIDER") != "" && os.Getenv("CLIENT_ID") != "" && os.Getenv("CLIENT_SECRET") != "" { + authProvider = strings.ToLower(os.Getenv("AUTH_PROVIDER")) + if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" { + return []string{authProvider, os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET")} + } else { + authProvider = "" + } + } else if config.Config.Server.AuthProvider != "" && config.Config.Server.ClientID != "" && config.Config.Server.ClientSecret != "" { + authProvider = strings.ToLower(config.Config.Server.AuthProvider) + if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" { + return []string{authProvider, config.Config.Server.ClientID, config.Config.Server.ClientSecret} + } + } + return []string{"", "", ""} +} + // GetMacAddr - get's mac address func getMacAddr() string { ifas, err := net.Interfaces()