diff --git a/auth/nodecallback.go b/auth/nodecallback.go index 4af44ca5..5345f111 100644 --- a/auth/nodecallback.go +++ b/auth/nodecallback.go @@ -155,8 +155,11 @@ func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) [] // Listens in /oidc/register/:regKey. func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) { - logger.Log(1, "RegisterNodeSSO\n") - + if auth_provider == nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("invalid login attempt")) + return + } vars := mux.Vars(r) // machineKeyStr this is not key but state @@ -165,8 +168,7 @@ func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) { if machineKeyStr == "" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("Wrong params")) - logger.Log(0, "Wrong params ", machineKeyStr) + w.Write([]byte("invalid login attempt")) return } diff --git a/auth/nodesession.go b/auth/nodesession.go index 0b4cd3cc..af6107a3 100644 --- a/auth/nodesession.go +++ b/auth/nodesession.go @@ -23,7 +23,6 @@ import ( // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured func SessionHandler(conn *websocket.Conn) { defer conn.Close() - logger.Log(1, "Running sessionHandler") // If reached here we have a session from user to handle... messageType, message, err := conn.ReadMessage() @@ -58,12 +57,20 @@ func SessionHandler(conn *websocket.Conn) { defer close(answer) defer close(timeout) + if _, err = logic.GetNetwork(loginMessage.Network); err != nil { + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + return + } + if loginMessage.User != "" { // handle basic auth // verify that server supports basic auth, then authorize the request with given credentials // check if user is allowed to join via node sso // i.e. user is admin or user has network permissions if !servercfg.IsBasicAuthEnabled() { - err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled")) + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logger.Log(0, "error during message writing:", err.Error()) } @@ -73,7 +80,7 @@ func SessionHandler(conn *websocket.Conn) { Password: loginMessage.Password, }) if err != nil { - err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User))) + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logger.Log(0, "error during message writing:", err.Error()) } @@ -81,7 +88,7 @@ func SessionHandler(conn *websocket.Conn) { } user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false) if err != nil { - err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User))) + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logger.Log(0, "error during message writing:", err.Error()) } @@ -99,6 +106,13 @@ func SessionHandler(conn *websocket.Conn) { return } } else { // handle SSO / OAuth + if auth_provider == nil { + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + logger.Log(0, "error during message writing:", err.Error()) + } + return + } redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr) err = conn.WriteMessage(messageType, []byte(redirectUrl)) if err != nil { @@ -135,7 +149,7 @@ func SessionHandler(conn *websocket.Conn) { case <-timeout: logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network) // the read from req.answerCh has timed out - err = conn.WriteMessage(messageType, []byte("Authentication server time out")) + err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logger.Log(0, "Error during message writing:", err.Error()) } diff --git a/logger/logger.go b/logger/logger.go index 5b90b6fa..ad258772 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -134,7 +134,7 @@ func Retrieve(filePath string) string { // FatalLog - exits os after logging func FatalLog(message ...string) { - fmt.Printf("[netmaker] Fatal: %s \n", MakeString(" ", message...)) + fmt.Printf("[%s] Fatal: %s \n", program, MakeString(" ", message...)) os.Exit(2) } diff --git a/netclient/command/commands.go b/netclient/command/commands.go index b8c7d082..7bb92f0a 100644 --- a/netclient/command/commands.go +++ b/netclient/command/commands.go @@ -30,14 +30,13 @@ func Join(cfg *config.ClientConfig, privateKey string) error { logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer) err = functions.JoinViaSSo(cfg, privateKey) if err != nil { - logger.Log(0, "Join via OIDC failed: ", err.Error()) + logger.Log(0, "Join failed: ", err.Error()) return err } if cfg.AccessKey == "" { - return errors.New("failed to get access key") + return errors.New("login failed") } - logger.Log(1, "Got an access key to ", cfg.Network, " via:", cfg.SsoServer) } logger.Log(1, "Joining network: ", cfg.Network) diff --git a/netclient/daemon/freebsd.go b/netclient/daemon/freebsd.go index 6eafcfb6..27238412 100644 --- a/netclient/daemon/freebsd.go +++ b/netclient/daemon/freebsd.go @@ -28,7 +28,7 @@ func SetupFreebsdDaemon() error { } err = ncutils.Copy(binarypath, EXEC_DIR+"netclient") if err != nil { - log.Println(err) + logger.Log(0, err.Error()) return err } diff --git a/netclient/daemon/macos.go b/netclient/daemon/macos.go index 8301ed3b..d59ccf65 100644 --- a/netclient/daemon/macos.go +++ b/netclient/daemon/macos.go @@ -25,7 +25,7 @@ func SetupMacDaemon() error { } err = ncutils.Copy(binarypath, MAC_EXEC_DIR+"netclient") if err != nil { - log.Println(err) + logger.Log(0, err.Error()) return err } diff --git a/netclient/daemon/systemd.go b/netclient/daemon/systemd.go index aa218734..5916c6c0 100644 --- a/netclient/daemon/systemd.go +++ b/netclient/daemon/systemd.go @@ -38,7 +38,7 @@ func SetupSystemDDaemon() error { } err = ncutils.Copy(binarypath, EXEC_DIR+"netclient") if err != nil { - log.Println(err) + logger.Log(0, err.Error()) return err } @@ -64,7 +64,7 @@ WantedBy=multi-user.target if !ncutils.FileExists("/etc/systemd/system/netclient.service") { err = os.WriteFile("/etc/systemd/system/netclient.service", servicebytes, 0644) if err != nil { - log.Println(err) + logger.Log(0, err.Error()) return err } } @@ -106,7 +106,7 @@ func RemoveSystemDServices() error { var err error if !ncutils.IsWindows() && isOnlyService() { if err != nil { - log.Println(err) + logger.Log(0, err.Error()) } ncutils.RunCmd("systemctl disable netclient.service", false) ncutils.RunCmd("systemctl disable netclient.timer", false) diff --git a/netclient/functions/common.go b/netclient/functions/common.go index a8b23c80..7b78d7ba 100644 --- a/netclient/functions/common.go +++ b/netclient/functions/common.go @@ -301,8 +301,7 @@ func WipeLocal(cfg *config.ClientConfig) error { if cfg.Node.Interface != "" { if ncutils.FileExists(dir + cfg.Node.Interface + ".conf") { if err := os.Remove(dir + cfg.Node.Interface + ".conf"); err != nil { - log.Println("error removing .conf:") - log.Println(err.Error()) + logger.Log(0, err.Error()) fail = true } } diff --git a/netclient/functions/join.go b/netclient/functions/join.go index c59553be..c6684a02 100644 --- a/netclient/functions/join.go +++ b/netclient/functions/join.go @@ -82,6 +82,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { } loginMsg.User = global_settings.User loginMsg.Password = string(pass) + fmt.Println("attempting login...") } msgTx, err := json.Marshal(loginMsg) @@ -101,7 +102,6 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { // Wait to receive something from server _, msg, err := conn.ReadMessage() if err != nil { - log.Println("Error in receive:", err) return err } // Print message from the netmaker controller to the user @@ -121,6 +121,11 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error { for { msgType, msg, err := conn.ReadMessage() if err != nil { + if msgType < 0 { + logger.Log(1, "received close message from server") + done <- struct{}{} + return + } // Error reading a message from the server if !strings.Contains(err.Error(), "normal") { logger.Log(0, "read:", err.Error()) diff --git a/netclient/main.go b/netclient/main.go index de79967b..8987e3b8 100644 --- a/netclient/main.go +++ b/netclient/main.go @@ -4,10 +4,10 @@ package main import ( - "log" "os" "runtime/debug" + "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/netclient/cli_options" "github.com/gravitl/netmaker/netclient/config" "github.com/gravitl/netmaker/netclient/functions" @@ -47,7 +47,7 @@ func main() { } else { err := app.Run(os.Args) if err != nil { - log.Fatal(err) + logger.FatalLog(err.Error()) } } }