diff --git a/main.go b/main.go index d545347..7075eb4 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "errors" "fmt" - "io" "log" "os" "path" @@ -83,22 +82,32 @@ func server(c *cli.Context) error { } ssh.Handle(func(s ssh.Session) { - log.Printf("New connection: user=%q remote=%q local=%q command=%q", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command()) + currentUser := s.Context().Value(userContextKey).(User) + log.Printf("New connection: sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email) + + if currentUser.ID < 1 { + fmt.Fprintf(s, "You are not authorized to access this server.\n") + return + } switch s.User() { case c.String("config-user"): + if !currentUser.IsAdmin { + fmt.Fprintf(s, "You are not an administrator.\n") + return + } if err := shell(c, s, s.Command(), db); err != nil { - io.WriteString(s, fmt.Sprintf("error: %v\n", err)) + fmt.Fprintf(s, "error: %v\n", err) } default: host, err := RemoteHostFromSession(s, db) if err != nil { - io.WriteString(s, fmt.Sprintf("error: %v\n", err)) + fmt.Fprintf(s, "error: %v\n", err) // FIXME: print available hosts return } if err := proxy(s, host); err != nil { - io.WriteString(s, fmt.Sprintf("error: %v\n", err)) + fmt.Fprintf(s, "error: %v\n", err) } } }) @@ -144,7 +153,8 @@ func server(c *cli.Context) error { return true } - return false + // always returning true to display a custom message for invalid users + return true })) log.Printf("SSH Server accepting connections on %s", c.String("bind-address"))