From b1b497faa4964ead1cd1a3f2423a8767a5d53081 Mon Sep 17 00:00:00 2001 From: 0xdcarns Date: Thu, 15 Sep 2022 10:23:19 -0400 Subject: [PATCH] PR comments addressed --- auth/google.go | 2 +- auth/nodecallback.go | 2 +- auth/nodesession.go | 7 +++---- auth/oidc.go | 1 - controllers/user.go | 6 +++++- ee/initialize.go | 2 +- ee/util.go | 10 +++++----- mq/publishers.go | 1 - netclient/functions/join.go | 21 +++++++++++---------- netclient/functions/mqpublish.go | 2 +- 10 files changed, 28 insertions(+), 26 deletions(-) diff --git a/auth/google.go b/auth/google.go index 22be2a47..3b614481 100644 --- a/auth/google.go +++ b/auth/google.go @@ -88,7 +88,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { } logger.Log(1, "completed google OAuth sigin in for", content.Email) - http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect) + http.Redirect(w, r, fmt.Sprintf("%s/login?login=%s&user=%s", servercfg.GetFrontendURL(), jwt, content.Email), http.StatusPermanentRedirect) } func getGoogleUserInfo(state string, code string) (*OAuthUser, error) { diff --git a/auth/nodecallback.go b/auth/nodecallback.go index d7bfae9c..4af44ca5 100644 --- a/auth/nodecallback.go +++ b/auth/nodecallback.go @@ -58,7 +58,7 @@ func HandleNodeSSOCallback(w http.ResponseWriter, r *http.Request) { // retrieve machinekey from state cache reqKeyIf, machineKeyFoundErr := netcache.Get(state) if machineKeyFoundErr != nil { - logger.Log(0, "requested machine state key expired before authorisation completed -", err.Error()) + logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error()) reqKeyIf = &netcache.CValue{ Network: "invalid", Value: state, diff --git a/auth/nodesession.go b/auth/nodesession.go index b848c2b8..0b4cd3cc 100644 --- a/auth/nodesession.go +++ b/auth/nodesession.go @@ -19,7 +19,7 @@ import ( // SessionHandler - called by the HTTP router when user // is calling netclient with --login-server parameter in order to authenticate // via SSO mechanism by OAuth2 protocol flow. -// This triggers a session start and it is managed by the flow implmented here and callback +// This triggers a session start and it is managed by the flow implemented here and callback // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured func SessionHandler(conn *websocket.Conn) { defer conn.Close() @@ -55,6 +55,8 @@ func SessionHandler(conn *websocket.Conn) { // TBD: what should be the timeout here ? timeout := make(chan bool, 1) answer := make(chan string, 1) + defer close(answer) + defer close(timeout) if loginMessage.User != "" { // handle basic auth // verify that server supports basic auth, then authorize the request with given credentials @@ -149,7 +151,4 @@ func SessionHandler(conn *websocket.Conn) { logger.Log(0, "write close:", err.Error()) return } - time.After(time.Second) - close(answer) - close(timeout) } diff --git a/auth/oidc.go b/auth/oidc.go index ad03b1c9..11115728 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -62,7 +62,6 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect) return } - logger.Log(3, "using state string:", oauth_state_string) var url = auth_provider.AuthCodeURL(oauth_state_string) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } diff --git a/controllers/user.go b/controllers/user.go index f53ab2f1..1f72b3ff 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -476,7 +476,11 @@ func socketHandler(w http.ResponseWriter, r *http.Request) { // Upgrade our raw HTTP connection to a websocket based one conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - logger.Log(0, "error during connection upgrade for node SSO sign-in:", err.Error()) + logger.Log(0, "error during connection upgrade for node sign-in:", err.Error()) + return + } + if conn == nil { + logger.Log(0, "failed to establish web-socket connection during node sign-in") return } // Start handling the session diff --git a/ee/initialize.go b/ee/initialize.go index 665e7729..9439d6ac 100644 --- a/ee/initialize.go +++ b/ee/initialize.go @@ -13,7 +13,7 @@ import ( // InitEE - Initialize EE Logic func InitEE() { - SetIsEnterprise() + setIsEnterprise() models.SetLogo(retrieveEELogo()) controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers) logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { diff --git a/ee/util.go b/ee/util.go index 90705a0b..7e923aaf 100644 --- a/ee/util.go +++ b/ee/util.go @@ -8,16 +8,16 @@ import ( var isEnterprise bool -// SetIsEnterprise - sets server to use enterprise features -func SetIsEnterprise() { - isEnterprise = true -} - // IsEnterprise - checks if enterprise binary or not func IsEnterprise() bool { return isEnterprise } +// setIsEnterprise - sets server to use enterprise features +func setIsEnterprise() { + isEnterprise = true +} + // base64encode - base64 encode helper function func base64encode(input []byte) string { return base64.StdEncoding.EncodeToString(input) diff --git a/mq/publishers.go b/mq/publishers.go index 9e6cc521..4c45ea7b 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -229,7 +229,6 @@ func collectServerMetrics(networks []models.Network) { func pushMetricsToExporter(metrics models.Metrics) error { logger.Log(2, "----> Pushing metrics to exporter") - SetupMQTT() data, err := json.Marshal(metrics) if err != nil { return errors.New("failed to marshal metrics: " + err.Error()) diff --git a/netclient/functions/join.go b/netclient/functions/join.go index fc53e295..c59553be 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -13,7 +13,6 @@ import ( "runtime" "strings" "syscall" - "time" "github.com/gorilla/websocket" "github.com/gravitl/netmaker/logger" @@ -56,7 +55,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { // Dial the netmaker server controller conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil) if err != nil { - logger.Log(0, fmt.Sprintf("Error connecting to %s : %s", cfg.Server.API, err.Error())) + logger.Log(0, fmt.Sprintf("error connecting to %s : %s", cfg.Server.API, err.Error())) return err } // Don't forget to close when finished @@ -113,14 +112,14 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { // An answer from the server. // Server waits ~5 min - If takes too long timeout will be triggered by the server done := make(chan struct{}) + defer close(done) // Following code will run in a separate go routine // it reads a message from the server which either contains 'AccessToken:' string or not // if not - then it contains an Error to display. // if yes - then AccessToken is to be used to proceed joining the network go func() { - defer close(done) for { - _, msg, err := conn.ReadMessage() + msgType, msg, err := conn.ReadMessage() if err != nil { // Error reading a message from the server if !strings.Contains(err.Error(), "normal") { @@ -128,13 +127,19 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { } return } + + if msgType == websocket.CloseMessage { + logger.Log(1, "received close message from server") + done <- struct{}{} + return + } // Get the access token from the response if strings.Contains(string(msg), "AccessToken: ") { // Access was granted rxToken := strings.TrimPrefix(string(msg), "AccessToken: ") accesstoken, err := config.ParseAccessToken(rxToken) if err != nil { - log.Printf("Failed to parse received access token %s,err=%s\n", accesstoken, err.Error()) + logger.Log(0, fmt.Sprintf("failed to parse received access token %s,err=%s\n", accesstoken, err.Error())) return } @@ -159,7 +164,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { logger.Log(1, "finished") return nil case <-interrupt: - log.Println("interrupt") + logger.Log(0, "interrupt received, closing connection") // Cleanly close the connection by sending a close message and then // waiting (with timeout) for the server to close the connection. err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -167,10 +172,6 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { logger.Log(0, "write close:", err.Error()) return err } - select { - case <-done: - case <-time.After(time.Second): - } return nil } } diff --git a/netclient/functions/mqpublish.go b/netclient/functions/mqpublish.go index 766c19ed..8341a4fb 100644 --- a/netclient/functions/mqpublish.go +++ b/netclient/functions/mqpublish.go @@ -167,7 +167,7 @@ func publishMetrics(nodeCfg *config.ClientConfig) { logger.Log(1, "failed to authenticate when publishing metrics", err.Error()) return } - url := "https://" + nodeCfg.Server.API + "/api/nodes/" + nodeCfg.Network + "/" + nodeCfg.Node.ID + url := fmt.Sprintf("https://%s/api/nodes/%s/%s", nodeCfg.Server.API, nodeCfg.Network, nodeCfg.Node.ID) response, err := API("", http.MethodGet, url, token) if err != nil { logger.Log(1, "failed to read from server during metrics publish", err.Error())