diff --git a/examples/integration/_client.sh b/examples/integration/_client.sh index 8d318f6..98e7eef 100755 --- a/examples/integration/_client.sh +++ b/examples/integration/_client.sh @@ -51,3 +51,6 @@ ssh sshportal -l admin config backup --indent --ignore-events > backup-2 set -xe diff backup-1.clean backup-2.clean ) + +# bastion +# TODO diff --git a/main.go b/main.go index 5cc3bb3..2fd6b8f 100644 --- a/main.go +++ b/main.go @@ -2,19 +2,25 @@ package main import ( "bytes" + "encoding/json" "fmt" + "io" "log" "math/rand" "net" "os" + "os/exec" "path" "strings" + "syscall" "time" + "unsafe" "github.com/gliderlabs/ssh" "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/mysql" _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/kr/pty" "github.com/urfave/cli" gossh "golang.org/x/crypto/ssh" ) @@ -87,6 +93,10 @@ func main() { Usage: "Do not print errors, if any", }, }, + }, { + Name: "_test_server", + Hidden: true, + Action: testServer, }, } if err := app.Run(os.Args); err != nil { @@ -224,3 +234,68 @@ func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error { } return nil } + +// testServer is an hidden handler used for integration tests +func testServer(c *cli.Context) error { + ssh.Handle(func(s ssh.Session) { + helloMsg := struct { + User string + Environ []string + Command []string + }{ + User: s.User(), + Environ: s.Environ(), + Command: s.Command(), + } + enc := json.NewEncoder(s) + if err := enc.Encode(&helloMsg); err != nil { + log.Fatalf("failed to write helloMsg: %v", err) + } + var cmd *exec.Cmd + if s.Command() == nil { + cmd = exec.Command("/bin/sh") // #nosec + } else { + cmd = exec.Command(s.Command()[0], s.Command()[1:]...) // #nosec + } + ptyReq, winCh, isPty := s.Pty() + var cmdErr error + if isPty { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + f, err := pty.Start(cmd) + if err != nil { + fmt.Fprintf(s, "failed to run command: %v\n", err) + _ = s.Exit(1) + return + } + go func() { + for win := range winCh { + _, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), + uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec + } + }() + go func() { + _, _ = io.Copy(f, s) // stdin + }() + _, _ = io.Copy(s, f) // stdout + cmdErr = cmd.Wait() + } else { + //cmd.Stdin = s + cmd.Stdout = s + cmd.Stderr = s + cmdErr = cmd.Run() + } + + if cmdErr != nil { + if exitError, ok := cmdErr.(*exec.ExitError); ok { + waitStatus := exitError.Sys().(syscall.WaitStatus) + _ = s.Exit(waitStatus.ExitStatus()) + return + } + } + waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus) + _ = s.Exit(waitStatus.ExitStatus()) + }) + + log.Println("starting ssh server on port 2222...") + return ssh.ListenAndServe(":2222", nil) +}