sshportal/main.go

307 lines
7.1 KiB
Go
Raw Normal View History

2017-09-30 19:12:43 +08:00
package main
import (
2018-01-01 17:41:21 +08:00
"bytes"
2018-01-02 17:57:18 +08:00
"encoding/json"
2017-09-30 19:12:43 +08:00
"fmt"
2018-01-02 17:57:18 +08:00
"io"
2017-09-30 19:12:43 +08:00
"log"
2017-11-14 08:29:25 +08:00
"math/rand"
2018-01-01 17:41:21 +08:00
"net"
2017-09-30 19:12:43 +08:00
"os"
2018-01-02 17:57:18 +08:00
"os/exec"
2017-09-30 19:12:43 +08:00
"path"
2017-11-08 02:40:14 +08:00
"strings"
2018-01-02 17:57:18 +08:00
"syscall"
2017-11-14 08:29:25 +08:00
"time"
2018-01-02 17:57:18 +08:00
"unsafe"
2017-09-30 19:12:43 +08:00
"github.com/gliderlabs/ssh"
2017-10-30 23:48:14 +08:00
"github.com/jinzhu/gorm"
2017-10-31 00:12:04 +08:00
_ "github.com/jinzhu/gorm/dialects/mysql"
2017-10-30 23:48:14 +08:00
_ "github.com/jinzhu/gorm/dialects/sqlite"
2018-01-02 17:57:18 +08:00
"github.com/kr/pty"
2017-09-30 19:12:43 +08:00
"github.com/urfave/cli"
2017-11-04 05:54:16 +08:00
gossh "golang.org/x/crypto/ssh"
2017-09-30 19:12:43 +08:00
)
2017-11-14 07:38:23 +08:00
// Build metadata. Version is maintained by hand; the Git* variables are left
// empty in the source and are meant to be filled in by the build system.
var (
	// Version should be updated by hand at each release.
	Version = "1.7.1+dev"
	// GitTag will be overwritten automatically by the build system.
	GitTag string
	// GitSha will be overwritten automatically by the build system.
	GitSha string
	// GitBranch will be overwritten automatically by the build system.
	GitBranch string
)
2017-11-02 05:09:08 +08:00
2017-09-30 19:12:43 +08:00
func main() {
2017-11-14 08:29:25 +08:00
rand.Seed(time.Now().UnixNano())
2017-09-30 19:12:43 +08:00
app := cli.NewApp()
app.Name = path.Base(os.Args[0])
app.Author = "Manfred Touron"
2017-12-04 01:18:17 +08:00
app.Version = Version + " (" + GitSha + ")"
2017-09-30 19:12:43 +08:00
app.Email = "https://github.com/moul/sshportal"
app.Commands = []cli.Command{
{
Name: "server",
Usage: "Start sshportal server",
Action: server,
Flags: []cli.Flag{
cli.StringFlag{
Name: "bind-address, b",
EnvVar: "SSHPORTAL_BIND",
Value: ":2222",
Usage: "SSH server bind address",
},
cli.StringFlag{
Name: "db-driver",
Value: "sqlite3",
Usage: "GORM driver (sqlite3)",
},
cli.StringFlag{
Name: "db-conn",
Value: "./sshportal.db",
Usage: "GORM connection string",
},
cli.BoolFlag{
Name: "debug, D",
Usage: "Display debug information",
},
cli.StringFlag{
Name: "aes-key",
Usage: "Encrypt sensitive data in database (length: 16, 24 or 32)",
},
2018-01-02 23:31:34 +08:00
cli.StringFlag{
Name: "logs-location",
2018-01-04 18:45:05 +08:00
Value: "./sshportal",
2018-01-02 23:31:34 +08:00
Usage: "Store user session files",
},
},
2018-01-01 17:41:21 +08:00
}, {
Name: "healthcheck",
Action: healthcheck,
Flags: []cli.Flag{
cli.StringFlag{
Name: "addr, a",
Value: "localhost:2222",
Usage: "sshportal server address",
},
cli.BoolFlag{
Name: "wait, w",
Usage: "Loop indefinitely until sshportal is ready",
},
cli.BoolFlag{
Name: "quiet, q",
Usage: "Do not print errors, if any",
2018-01-01 17:41:21 +08:00
},
},
2018-01-02 17:57:18 +08:00
}, {
Name: "_test_server",
Hidden: true,
Action: testServer,
2017-11-24 21:29:41 +08:00
},
2017-09-30 19:12:43 +08:00
}
if err := app.Run(os.Args); err != nil {
log.Fatalf("error: %v", err)
}
2017-09-30 19:12:43 +08:00
}
func server(c *cli.Context) error {
2017-11-24 21:29:41 +08:00
switch len(c.String("aes-key")) {
case 0, 16, 24, 32:
default:
return fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
}
2017-11-19 08:18:17 +08:00
// db
2017-11-29 17:28:33 +08:00
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
2017-10-30 23:48:14 +08:00
if err != nil {
return err
}
2017-12-04 01:18:17 +08:00
defer func() {
if err2 := db.Close(); err2 != nil {
panic(err2)
}
}()
2017-11-19 08:18:17 +08:00
if err = db.DB().Ping(); err != nil {
return err
}
2017-10-31 00:12:04 +08:00
if c.Bool("debug") {
db.LogMode(true)
}
if err = dbInit(db); err != nil {
2017-10-30 23:48:14 +08:00
return err
}
2017-09-30 19:12:43 +08:00
opts := []ssh.Option{}
// custom PublicKeyAuth handler
opts = append(opts, ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)))
opts = append(opts, ssh.PasswordAuth(passwordAuthHandler(db, c)))
2018-01-02 23:31:34 +08:00
// retrieve sshportal SSH private key from database
2017-11-04 05:54:16 +08:00
opts = append(opts, func(srv *ssh.Server) error {
2017-11-23 16:58:32 +08:00
var key SSHKey
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
2017-11-04 05:54:16 +08:00
return err
}
2017-11-29 17:28:33 +08:00
SSHKeyDecrypt(c.String("aes-key"), &key)
var signer gossh.Signer
signer, err = gossh.ParsePrivateKey([]byte(key.PrivKey))
2017-11-04 05:54:16 +08:00
if err != nil {
return err
}
srv.AddHostKey(signer)
return nil
})
// create TCP listening socket
ln, err := net.Listen("tcp", c.String("bind-address"))
if err != nil {
return err
}
// configure server
srv := &ssh.Server{
Addr: c.String("bind-address"),
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
Version: fmt.Sprintf("sshportal-%s", Version),
ChannelHandler: channelHandler,
}
for _, opt := range opts {
if err := srv.SetOption(opt); err != nil {
return err
}
}
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
return srv.Serve(ln)
2017-09-30 19:12:43 +08:00
}
2018-01-01 17:41:21 +08:00
// healthcheck is the action behind the "healthcheck" command. It probes an
// sshportal server without requiring an ssh client or an ssh key (used for
// Docker's HEALTHCHECK).
//
// Without --wait a single probe is made: its error is returned as-is, or
// replaced by an empty exit error with status 1 when --quiet is set.
// With --wait the probe is retried once per second until it succeeds; each
// failure is logged unless --quiet is set.
func healthcheck(c *cli.Context) error {
	var (
		addr  = c.String("addr")
		quiet = c.Bool("quiet")
	)
	config := gossh.ClientConfig{
		User:            "healthcheck",
		HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
		Auth:            []gossh.AuthMethod{gossh.Password("healthcheck")},
	}

	// One-shot mode.
	if !c.Bool("wait") {
		err := healthcheckOnce(addr, config, quiet)
		switch {
		case err == nil:
			return nil
		case quiet:
			return cli.NewExitError("", 1)
		default:
			return err
		}
	}

	// Wait mode: keep probing until the server answers correctly.
	for {
		err := healthcheckOnce(addr, config, quiet)
		if err == nil {
			return nil
		}
		if !quiet {
			log.Printf("error: %v", err)
		}
		time.Sleep(time.Second)
	}
}
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
client, err := gossh.Dial("tcp", addr, &config)
2018-01-01 17:41:21 +08:00
if err != nil {
return err
}
session, err := client.NewSession()
if err != nil {
return err
}
defer func() {
if err := session.Close(); err != nil {
if !quiet {
log.Printf("failed to close session: %v", err)
}
2018-01-01 17:41:21 +08:00
}
}()
var b bytes.Buffer
session.Stdout = &b
if err := session.Run(""); err != nil {
return err
}
stdout := strings.TrimSpace(b.String())
if stdout != "OK" {
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
}
return nil
}
2018-01-02 17:57:18 +08:00
// testServer is a hidden handler used for integration tests.
//
// It serves SSH on port 2222. For every session it first writes a JSON object
// describing the session (user, environment, command), then runs the
// requested command — or /bin/sh when no command was given — and exits the
// session with the command's exit status. When the client requested a pty the
// command is attached to a new pty and window-size changes are forwarded to
// it; otherwise the command's stdout and stderr are wired to the session.
//
// If the hello message cannot be written, or the command cannot be started,
// the session exits with status 1.
func testServer(c *cli.Context) error {
	ssh.Handle(func(s ssh.Session) {
		helloMsg := struct {
			User    string
			Environ []string
			Command []string
		}{
			User:    s.User(),
			Environ: s.Environ(),
			Command: s.Command(),
		}
		enc := json.NewEncoder(s)
		if err := enc.Encode(&helloMsg); err != nil {
			// Only this session is affected (e.g. the client went away);
			// do not take the whole server down for it.
			log.Printf("failed to write helloMsg: %v", err)
			_ = s.Exit(1)
			return
		}

		var cmd *exec.Cmd
		// Check the length rather than nil so an empty, non-nil command
		// slice cannot cause an out-of-range index below.
		if args := s.Command(); len(args) == 0 {
			cmd = exec.Command("/bin/sh") // #nosec
		} else {
			cmd = exec.Command(args[0], args[1:]...) // #nosec
		}

		ptyReq, winCh, isPty := s.Pty()
		var cmdErr error
		if isPty {
			cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
			f, err := pty.Start(cmd)
			if err != nil {
				fmt.Fprintf(s, "failed to run command: %v\n", err)
				_ = s.Exit(1)
				return
			}
			// Release the pty once the session is over; the close error
			// carries no useful information at that point.
			defer func() { _ = f.Close() }()

			// Forward window-size changes from the client to the pty.
			go func() {
				for win := range winCh {
					_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
						uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) // #nosec
				}
			}()
			go func() {
				_, _ = io.Copy(f, s) // stdin
			}()
			_, _ = io.Copy(s, f) // stdout
			cmdErr = cmd.Wait()
		} else {
			//cmd.Stdin = s
			cmd.Stdout = s
			cmd.Stderr = s
			cmdErr = cmd.Run()
		}

		// ProcessState is only populated once the process has been waited
		// for. It stays nil when the command could not be started (for
		// example, executable not found), in which case there is no exit
		// status to report and dereferencing it would panic.
		if cmd.ProcessState == nil {
			fmt.Fprintf(s, "failed to run command: %v\n", cmdErr)
			_ = s.Exit(1)
			return
		}

		// For a non-zero exit, the *exec.ExitError returned by Wait/Run wraps
		// this same ProcessState, so one code path covers both outcomes.
		waitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus)
		_ = s.Exit(waitStatus.ExitStatus())
	})

	log.Println("starting ssh server on port 2222...")
	return ssh.ListenAndServe(":2222", nil)
}