From 8aa185d88000fa8c6b2513e719d4890cd5652fd1 Mon Sep 17 00:00:00 2001 From: Aceix Date: Wed, 13 Dec 2023 10:04:09 +0000 Subject: [PATCH] feat(NET-678): add saas support to nmctl (#2687) * feat(NET-678): add saas support to nmctl * fix(NET-678): fix context endpoint for sso --- cli/cmd/context/set.go | 33 +++- cli/config/config.go | 2 + cli/functions/http_client.go | 285 ++++++++++++++++++++++++++++++++--- models/structs.go | 51 +++++++ 4 files changed, 344 insertions(+), 27 deletions(-) diff --git a/cli/cmd/context/set.go b/cli/cmd/context/set.go index ddd5bb96..ba3f4de6 100644 --- a/cli/cmd/context/set.go +++ b/cli/cmd/context/set.go @@ -1,9 +1,11 @@ package context import ( + "fmt" "log" "github.com/gravitl/netmaker/cli/config" + "github.com/gravitl/netmaker/cli/functions" "github.com/spf13/cobra" ) @@ -13,6 +15,8 @@ var ( password string masterKey string sso bool + tenantId string + saas bool ) var contextSetCmd = &cobra.Command{ @@ -27,10 +31,28 @@ var contextSetCmd = &cobra.Command{ Password: password, MasterKey: masterKey, SSO: sso, + TenantId: tenantId, + Saas: saas, } - if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO { - cmd.Usage() - log.Fatal("Either username/password or master key is required") + if !ctx.Saas { + if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO { + log.Fatal("Either username/password or master key is required") + cmd.Usage() + } + if ctx.Endpoint == "" { + log.Fatal("Endpoint is required when for self-hosted tenants") + cmd.Usage() + } + } else { + if ctx.TenantId == "" { + log.Fatal("Tenant ID is required for SaaS tenants") + cmd.Usage() + } + ctx.Endpoint = fmt.Sprintf(functions.TenantUrlTemplate, tenantId) + if ctx.Username == "" && ctx.Password == "" && !ctx.SSO { + log.Fatal("Username/password is required for non-SSO SaaS contexts") + cmd.Usage() + } } config.SetContext(args[0], ctx) }, @@ -38,11 +60,12 @@ var contextSetCmd = &cobra.Command{ func init() { contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server") - contextSetCmd.MarkFlagRequired("endpoint") contextSetCmd.Flags().StringVar(&username, "username", "", "Username") contextSetCmd.Flags().StringVar(&password, "password", "", "Password") contextSetCmd.MarkFlagsRequiredTogether("username", "password") - contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO) ?") + contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO)?") contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key") + contextSetCmd.Flags().StringVar(&tenantId, "tenant_id", "", "Tenant ID") + contextSetCmd.Flags().BoolVar(&saas, "saas", false, "Is this context for a SaaS tenant?") rootCmd.AddCommand(contextSetCmd) } diff --git a/cli/config/config.go b/cli/config/config.go index 8335f8d7..ba32c48d 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -18,6 +18,8 @@ type Context struct { Current bool `yaml:"current,omitempty"` AuthToken string `yaml:"auth_token,omitempty"` SSO bool `yaml:"sso,omitempty"` + TenantId string `yaml:"tenant_id,omitempty"` + Saas bool `yaml:"saas,omitempty"` } var ( diff --git a/cli/functions/http_client.go b/cli/functions/http_client.go index c4c984f3..fdb4210a 100644 --- a/cli/functions/http_client.go +++ b/cli/functions/http_client.go @@ -11,11 +11,19 @@ import ( "os" "os/signal" "strings" + "time" "github.com/gorilla/websocket" "github.com/gravitl/netmaker/cli/config" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" + "golang.org/x/exp/slog" +) + +const ( + ambBaseUrl = "https://api.accounts.netmaker.io" + TenantUrlTemplate = "https://api-%s.app.prod.netmaker.io" + ambOauthWssUrl = "wss://api.accounts.netmaker.io/api/v1/auth/sso" ) func ssoLogin(endpoint string) string { @@ -81,34 +89,57 @@ func getAuthToken(ctx config.Context, force bool) string { if !force && ctx.AuthToken != "" { return ctx.AuthToken } - if ctx.SSO { - authToken := ssoLogin(ctx.Endpoint) + if !ctx.Saas { + if ctx.SSO { + authToken := ssoLogin(ctx.Endpoint) + config.SetAuthToken(authToken) + return authToken + } + authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password} + payload, err := json.Marshal(authParams) + if err != nil { + log.Fatal(err) + } + res, err := http.Post(ctx.Endpoint+"/api/users/adm/authenticate", "application/json", bytes.NewReader(payload)) + if err != nil { + log.Fatal(err) + } + defer res.Body.Close() + resBodyBytes, err := io.ReadAll(res.Body) + if err != nil { + log.Fatalf("Client could not read response body: %s", err) + } + if res.StatusCode != http.StatusOK { + log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes)) + } + body := new(models.SuccessResponse) + if err := json.Unmarshal(resBodyBytes, body); err != nil { + log.Fatalf("Error unmarshalling JSON: %s", err) + } + authToken := body.Response.(map[string]any)["AuthToken"].(string) config.SetAuthToken(authToken) return authToken } - authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password} - payload, err := json.Marshal(authParams) + + if !ctx.SSO { + sToken, _, err := basicAuthSaasSignin(ctx.Username, ctx.Password) + if err != nil { + log.Fatal(err) + } + authToken, _, err := tenantLogin(ctx, sToken) + if err != nil { + log.Fatal(err) + } + config.SetAuthToken(authToken) + return authToken + } + + accessToken, err := loginSaaSOauth(&models.SsoLoginReqDto{OauthProvider: "oidc"}, ctx.TenantId) if err != nil { log.Fatal(err) } - res, err := http.Post(ctx.Endpoint+"/api/users/adm/authenticate", "application/json", bytes.NewReader(payload)) - if err != nil { - log.Fatal(err) - } - resBodyBytes, err := io.ReadAll(res.Body) - if err != nil { - log.Fatalf("Client could not read response body: %s", err) - } - if res.StatusCode != http.StatusOK { - log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes)) - } - body := new(models.SuccessResponse) - if err := json.Unmarshal(resBodyBytes, body); err != nil { - log.Fatalf("Error unmarshalling JSON: %s", err) - } - authToken := body.Response.(map[string]any)["AuthToken"].(string) - config.SetAuthToken(authToken) - return authToken + config.SetAuthToken(accessToken) + return accessToken } func request[T any](method, route string, payload any) *T { @@ -188,3 +219,213 @@ func get(route string) string { } return string(bodyBytes) } + +func basicAuthSaasSignin(email, password string) (string, http.Header, error) { + payload := models.SignInReqDto{ + FormFields: []models.FormField{ + { + Id: "email", + Value: email, + }, + { + Id: "password", + Value: password, + }, + }, + } + + var res models.SignInResDto + + // Create a new HTTP client with a timeout + client := &http.Client{ + Timeout: 30 * time.Second, + } + + // Create the request body + payloadBuf := new(bytes.Buffer) + json.NewEncoder(payloadBuf).Encode(payload) + + // Create the request + req, err := http.NewRequest("POST", ambBaseUrl+"/auth/signin", payloadBuf) + if err != nil { + return "", http.Header{}, err + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("rid", "thirdpartyemailpassword") + + // Send the request + resp, err := client.Do(req) + if err != nil { + return "", http.Header{}, err + } + defer resp.Body.Close() + + // Check the response status code + if resp.StatusCode != http.StatusOK { + return "", http.Header{}, fmt.Errorf("error authenticating: %s", resp.Status) + } + + // Copy the response headers + resHeaders := resp.Header + + // Decode the response body + err = json.NewDecoder(resp.Body).Decode(&res) + if err != nil { + return "", http.Header{}, err + } + + sToken := resHeaders.Get(models.ResHeaderKeyStAccessToken) + encodedAccessToken := url.QueryEscape(sToken) + + return encodedAccessToken, resHeaders, nil +} + +func tenantLogin(ctx config.Context, sToken string) (string, string, error) { + url := fmt.Sprintf("%s/api/v1/tenant/login?tenant_id=%s", ambBaseUrl, ctx.TenantId) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, url, nil) + + if err != nil { + return "", "", err + } + req.Header.Add("Cookie", fmt.Sprintf("sAccessToken=%s", sToken)) + + res, err := client.Do(req) + if err != nil { + return "", "", err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return "", "", err + } + + data := models.TenantLoginResDto{} + json.Unmarshal(body, &data) + + return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, ctx.TenantId), nil +} + +func loginSaaSOauth(payload *models.SsoLoginReqDto, tenantId string) (string, error) { + socketUrl := ambOauthWssUrl + // Dial the netmaker server controller + conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil) + if err != nil { + slog.Error("error connecting to endpoint ", "url", socketUrl, "err", err) + return "", err + } + + defer conn.Close() + return handleServerSSORegisterConn(payload, conn, tenantId) +} + +func handleServerSSORegisterConn(payload *models.SsoLoginReqDto, conn *websocket.Conn, tenantId string) (string, error) { + reqData, err := json.Marshal(payload) + if err != nil { + return "", err + } + if err := conn.WriteMessage(websocket.TextMessage, reqData); err != nil { + return "", err + } + dataCh := make(chan string) + defer close(dataCh) + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + go func() { + for { + msgType, msg, err := conn.ReadMessage() + if err != nil { + if msgType < 0 { + slog.Info("received close message from server") + return + } + if !strings.Contains(err.Error(), "normal") { // Error reading a message from the server + slog.Error("error msg", "err", err) + } + return + } + if msgType == websocket.CloseMessage { + slog.Info("received close message from server") + return + } + if strings.Contains(string(msg), "auth/sso") { + fmt.Printf("Please visit:\n %s \nto authenticate\n", string(msg)) + } else { + var res models.SsoLoginData + if err := json.Unmarshal(msg, &res); err != nil { + return + } + accessToken, _, err := tenantLoginV2(res.AmbAccessToken, tenantId, res.Username) + if err != nil { + slog.Error("error logging in tenant", "err", err) + dataCh <- "" + return + } + dataCh <- accessToken + return + } + } + }() + + for { + select { + case accessToken := <-dataCh: + if accessToken == "" { + slog.Info("error getting access token") + return "", fmt.Errorf("error getting access token") + } + return accessToken, nil + case <-time.After(30 * time.Second): + slog.Error("authentiation timed out") + os.Exit(1) + case <-interrupt: + slog.Info("interrupt received, closing connection") + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Fatal(err) + } + os.Exit(1) + } + } +} + +func tenantLoginV2(ambJwt, tenantId, email string) (string, string, error) { + url := fmt.Sprintf("%s/api/v1/tenant/login/custom", ambBaseUrl) + payload := models.LoginReqDto{ + Email: email, + TenantID: tenantId, + } + payloadBuf := new(bytes.Buffer) + json.NewEncoder(payloadBuf).Encode(payload) + + client := &http.Client{} + req, err := http.NewRequest("POST", url, payloadBuf) + if err != nil { + slog.Error("error creating request", "err", err) + return "", "", err + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", ambJwt)) + + res, err := client.Do(req) + if err != nil { + slog.Error("error sending request", "err", err) + return "", "", err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + slog.Error("error reading response body", "err", err) + return "", "", err + } + + data := models.TenantLoginResDto{} + json.Unmarshal(body, &data) + + return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, tenantId), nil +} diff --git a/models/structs.go b/models/structs.go index 8ce29ced..66abc500 100644 --- a/models/structs.go +++ b/models/structs.go @@ -307,3 +307,54 @@ type LicenseLimits struct { Clients int `json:"clients"` Networks int `json:"networks"` } + +type SignInReqDto struct { + FormFields FormFields `json:"formFields"` +} + +type FormField struct { + Id string `json:"id"` + Value any `json:"value"` +} + +type FormFields []FormField + +type SignInResDto struct { + Status string `json:"status"` + User User `json:"user"` +} + +type TenantLoginResDto struct { + Code int `json:"code"` + Message string `json:"message"` + Response struct { + UserName string `json:"UserName"` + AuthToken string `json:"AuthToken"` + } `json:"response"` +} + +type SsoLoginReqDto struct { + OauthProvider string `json:"oauthprovider"` +} + +type SsoLoginResDto struct { + User string `json:"UserName"` + AuthToken string `json:"AuthToken"` +} + +type SsoLoginData struct { + Expiration time.Time `json:"expiration"` + OauthProvider string `json:"oauthprovider,omitempty"` + OauthCode string `json:"oauthcode,omitempty"` + Username string `json:"username,omitempty"` + AmbAccessToken string `json:"ambaccesstoken,omitempty"` +} + +type LoginReqDto struct { + Email string `json:"email"` + TenantID string `json:"tenant_id"` +} + +const ( + ResHeaderKeyStAccessToken = "St-Access-Token" +)