sshportal/main.go

202 lines
5 KiB
Go
Raw Normal View History

package main // import "moul.io/sshportal"
2017-09-30 19:12:43 +08:00
import (
"fmt"
"log"
2018-11-16 02:38:18 +08:00
"math"
2017-11-14 08:29:25 +08:00
"math/rand"
2018-01-01 17:41:21 +08:00
"net"
2017-09-30 19:12:43 +08:00
"os"
"path"
2017-11-14 08:29:25 +08:00
"time"
2017-09-30 19:12:43 +08:00
2017-10-30 23:48:14 +08:00
"github.com/jinzhu/gorm"
2017-10-31 00:12:04 +08:00
_ "github.com/jinzhu/gorm/dialects/mysql"
2017-10-30 23:48:14 +08:00
_ "github.com/jinzhu/gorm/dialects/sqlite"
2018-11-16 17:24:18 +08:00
"github.com/moul/ssh"
2017-09-30 19:12:43 +08:00
"github.com/urfave/cli"
2018-11-16 17:24:18 +08:00
gossh "golang.org/x/crypto/ssh"
2017-09-30 19:12:43 +08:00
)
2017-11-14 07:38:23 +08:00
// Version is the sshportal release version; it should be updated by hand at
// each release.
var Version string = "1.9.0+dev"

// Build metadata. Each value is empty unless the build system overwrites it.
var (
	// GitTag will be overwritten automatically by the build system.
	GitTag string

	// GitSha will be overwritten automatically by the build system.
	GitSha string

	// GitBranch will be overwritten automatically by the build system.
	GitBranch string
)
2017-11-02 05:09:08 +08:00
2017-09-30 19:12:43 +08:00
func main() {
2017-11-14 08:29:25 +08:00
rand.Seed(time.Now().UnixNano())
2017-09-30 19:12:43 +08:00
app := cli.NewApp()
app.Name = path.Base(os.Args[0])
app.Author = "Manfred Touron"
2017-12-04 01:18:17 +08:00
app.Version = Version + " (" + GitSha + ")"
app.Email = "https://moul.io/sshportal"
app.Commands = []cli.Command{
{
Name: "server",
Usage: "Start sshportal server",
Action: func(c *cli.Context) error {
if err := ensureLogDirectory(c.String("logs-location")); err != nil {
return err
}
cfg, err := parseServeConfig(c)
if err != nil {
return err
}
return server(cfg)
},
Flags: []cli.Flag{
cli.StringFlag{
Name: "bind-address, b",
EnvVar: "SSHPORTAL_BIND",
Value: ":2222",
Usage: "SSH server bind address",
},
cli.StringFlag{
Name: "db-driver",
EnvVar: "SSHPORTAL_DB_DRIVER",
Value: "sqlite3",
Usage: "GORM driver (sqlite3)",
},
cli.StringFlag{
Name: "db-conn",
EnvVar: "SSHPORTAL_DATABASE_URL",
Value: "./sshportal.db",
Usage: "GORM connection string",
},
cli.BoolFlag{
Name: "debug, D",
EnvVar: "SSHPORTAL_DEBUG",
Usage: "Display debug information",
},
cli.StringFlag{
Name: "aes-key",
EnvVar: "SSHPORTAL_AES_KEY",
Usage: "Encrypt sensitive data in database (length: 16, 24 or 32)",
},
2018-01-02 23:31:34 +08:00
cli.StringFlag{
Name: "logs-location",
EnvVar: "SSHPORTAL_LOGS_LOCATION",
Value: "./log",
Usage: "Store user session files",
2018-01-02 23:31:34 +08:00
},
2018-11-16 02:38:18 +08:00
cli.DurationFlag{
Name: "idle-timeout",
Value: 0,
Usage: "Duration before an inactive connection is timed out (0 to disable)",
},
},
2018-01-01 17:41:21 +08:00
}, {
Name: "healthcheck",
Action: func(c *cli.Context) error { return healthcheck(c.String("addr"), c.Bool("wait"), c.Bool("quiet")) },
2018-01-01 17:41:21 +08:00
Flags: []cli.Flag{
cli.StringFlag{
Name: "addr, a",
Value: "localhost:2222",
Usage: "sshportal server address",
},
cli.BoolFlag{
Name: "wait, w",
Usage: "Loop indefinitely until sshportal is ready",
},
cli.BoolFlag{
Name: "quiet, q",
Usage: "Do not print errors, if any",
2018-01-01 17:41:21 +08:00
},
},
2018-01-02 17:57:18 +08:00
}, {
Name: "_test_server",
Hidden: true,
Action: testServer,
2017-11-24 21:29:41 +08:00
},
2017-09-30 19:12:43 +08:00
}
if err := app.Run(os.Args); err != nil {
log.Fatalf("error: %v", err)
}
2017-09-30 19:12:43 +08:00
}
2018-11-16 17:24:18 +08:00
var defaultChannelHandler ssh.ChannelHandler
func server(c *configServe) (err error) {
var db = (*gorm.DB)(nil)
// try to setup the local DB
if db, err = gorm.Open(c.dbDriver, c.dbURL); err != nil {
return
2017-10-30 23:48:14 +08:00
}
2017-12-04 01:18:17 +08:00
defer func() {
origErr := err
err = db.Close()
if origErr != nil {
err = origErr
2017-12-04 01:18:17 +08:00
}
}()
2017-11-19 08:18:17 +08:00
if err = db.DB().Ping(); err != nil {
return
2017-10-31 00:12:04 +08:00
}
db.LogMode(c.debug)
if err = dbInit(db); err != nil {
return
}
2018-01-05 18:02:13 +08:00
// create TCP listening socket
ln, err := net.Listen("tcp", c.bindAddr)
if err != nil {
return err
}
// configure server
srv := &ssh.Server{
2018-11-16 17:24:18 +08:00
Addr: c.bindAddr,
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
Version: fmt.Sprintf("sshportal-%s", Version),
}
2018-11-16 17:24:18 +08:00
// configure channel handler
defaultSessionHandler := srv.GetChannelHandler("session")
defaultDirectTcpipHandler := srv.GetChannelHandler("direct-tcpip")
defaultChannelHandler = func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
switch newChan.ChannelType() {
case "session":
go defaultSessionHandler(srv, conn, newChan, ctx)
case "direct-tcpip":
go defaultDirectTcpipHandler(srv, conn, newChan, ctx)
default:
2018-11-16 22:44:29 +08:00
if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
log.Printf("failed to reject chan: %v", err)
}
2018-11-16 17:24:18 +08:00
}
}
srv.SetChannelHandler("session", nil)
srv.SetChannelHandler("direct-tcpip", nil)
srv.SetChannelHandler("default", channelHandler)
2018-11-16 02:38:18 +08:00
if c.idleTimeout != 0 {
srv.IdleTimeout = c.idleTimeout
// gliderlabs/ssh requires MaxTimeout to be non-zero if we want to use IdleTimeout.
// So, set it to the max value, because we don't want a max timeout.
srv.MaxTimeout = math.MaxInt64
}
for _, opt := range []ssh.Option{
// custom PublicKeyAuth handler
ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)),
ssh.PasswordAuth(passwordAuthHandler(db, c)),
// retrieve sshportal SSH private key from database
privateKeyFromDB(db, c.aesKey),
} {
if err := srv.SetOption(opt); err != nil {
return err
}
}
2018-11-16 02:56:10 +08:00
log.Printf("info: SSH Server accepting connections on %s, idle-timout=%v", c.bindAddr, c.idleTimeout)
return srv.Serve(ln)
2017-09-30 19:12:43 +08:00
}