sshportal/main.go

368 lines
9.4 KiB
Go
Raw Normal View History

2017-09-30 19:12:43 +08:00
package main
import (
2018-01-01 17:41:21 +08:00
"bytes"
2017-11-08 02:40:14 +08:00
"errors"
2017-09-30 19:12:43 +08:00
"fmt"
"log"
2017-11-14 08:29:25 +08:00
"math/rand"
2018-01-01 17:41:21 +08:00
"net"
2017-09-30 19:12:43 +08:00
"os"
"path"
2017-11-08 02:40:14 +08:00
"strings"
2017-11-14 08:29:25 +08:00
"time"
2017-09-30 19:12:43 +08:00
"github.com/gliderlabs/ssh"
2017-10-30 23:48:14 +08:00
"github.com/jinzhu/gorm"
2017-10-31 00:12:04 +08:00
_ "github.com/jinzhu/gorm/dialects/mysql"
2017-10-30 23:48:14 +08:00
_ "github.com/jinzhu/gorm/dialects/sqlite"
2017-09-30 19:12:43 +08:00
"github.com/urfave/cli"
2017-11-04 05:54:16 +08:00
gossh "golang.org/x/crypto/ssh"
2017-09-30 19:12:43 +08:00
)
2017-11-14 07:38:23 +08:00
// Version should be updated by hand at each release.
var Version = "1.6.0+dev"

// Build metadata. None of these values are set in the source; the build
// system fills them in.
var (
	// GitTag will be overwritten automatically by the build system.
	GitTag string
	// GitSha will be overwritten automatically by the build system.
	GitSha string
	// GitBranch will be overwritten automatically by the build system.
	GitBranch string
)
2017-11-02 05:09:08 +08:00
// sshportalContextKey is the type of every key sshportal stores in an
// ssh.Context. Using a dedicated type keeps these keys distinct from keys of
// any other type, even ones with the same string value.
type sshportalContextKey string

// Context keys written by the authentication handlers and read back by the
// session handler.
var userContextKey sshportalContextKey = "user"       // the resolved User
var messageContextKey sshportalContextKey = "message" // a string printed to the session
var errorContextKey sshportalContextKey = "error"     // an error that aborts the session
2017-09-30 19:12:43 +08:00
func main() {
2017-11-14 08:29:25 +08:00
rand.Seed(time.Now().UnixNano())
2017-09-30 19:12:43 +08:00
app := cli.NewApp()
app.Name = path.Base(os.Args[0])
app.Author = "Manfred Touron"
2017-12-04 01:18:17 +08:00
app.Version = Version + " (" + GitSha + ")"
2017-09-30 19:12:43 +08:00
app.Email = "https://github.com/moul/sshportal"
app.Commands = []cli.Command{
{
Name: "server",
Usage: "Start sshportal server",
Action: server,
Flags: []cli.Flag{
cli.StringFlag{
Name: "bind-address, b",
EnvVar: "SSHPORTAL_BIND",
Value: ":2222",
Usage: "SSH server bind address",
},
cli.StringFlag{
Name: "db-driver",
Value: "sqlite3",
Usage: "GORM driver (sqlite3)",
},
cli.StringFlag{
Name: "db-conn",
Value: "./sshportal.db",
Usage: "GORM connection string",
},
cli.BoolFlag{
Name: "debug, D",
Usage: "Display debug information",
},
cli.StringFlag{
Name: "config-user",
Usage: "SSH user that spawns a configuration shell",
Value: "admin",
},
cli.StringFlag{
Name: "healthcheck-user",
Usage: "SSH user that returns healthcheck status without checking the SSH key",
Value: "healthcheck",
},
cli.StringFlag{
Name: "aes-key",
Usage: "Encrypt sensitive data in database (length: 16, 24 or 32)",
},
},
2018-01-01 17:41:21 +08:00
}, {
Name: "healthcheck",
Action: healthcheck,
Flags: []cli.Flag{
cli.StringFlag{
Name: "addr, a",
Value: "localhost:2222",
Usage: "sshportal server address",
},
cli.BoolFlag{
Name: "wait, w",
Usage: "Loop indefinitely until sshportal is ready",
},
cli.BoolFlag{
Name: "quiet, q",
Usage: "Do not print errors, if any",
2018-01-01 17:41:21 +08:00
},
},
2017-11-24 21:29:41 +08:00
},
2017-09-30 19:12:43 +08:00
}
if err := app.Run(os.Args); err != nil {
log.Fatalf("error: %v", err)
}
2017-09-30 19:12:43 +08:00
}
func server(c *cli.Context) error {
2017-11-24 21:29:41 +08:00
switch len(c.String("aes-key")) {
case 0, 16, 24, 32:
default:
return fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
}
2017-11-19 08:18:17 +08:00
// db
2017-11-29 17:28:33 +08:00
db, err := gorm.Open(c.String("db-driver"), c.String("db-conn"))
2017-10-30 23:48:14 +08:00
if err != nil {
return err
}
2017-12-04 01:18:17 +08:00
defer func() {
if err2 := db.Close(); err2 != nil {
panic(err2)
}
}()
2017-11-19 08:18:17 +08:00
if err = db.DB().Ping(); err != nil {
return err
}
2017-10-31 00:12:04 +08:00
if c.Bool("debug") {
db.LogMode(true)
}
2017-10-30 23:48:14 +08:00
if err := dbInit(db); err != nil {
return err
}
2017-11-19 08:18:17 +08:00
// ssh server
2017-09-30 19:12:43 +08:00
ssh.Handle(func(s ssh.Session) {
currentUser := s.Context().Value(userContextKey).(User)
2018-01-01 17:41:21 +08:00
if s.User() != "healthcheck" {
log.Printf("New connection: sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
}
2017-11-08 02:40:14 +08:00
if err := s.Context().Value(errorContextKey); err != nil {
fmt.Fprintf(s, "error: %v\n", err)
return
}
2017-09-30 19:12:43 +08:00
2017-11-08 02:40:14 +08:00
if msg := s.Context().Value(messageContextKey); msg != nil {
fmt.Fprint(s, msg.(string))
}
switch username := s.User(); {
2017-12-04 16:34:52 +08:00
case username == c.String("healthcheck-user"):
fmt.Fprintln(s, "OK")
return
case username == currentUser.Name || username == currentUser.Email || username == c.String("config-user"):
2017-11-02 00:00:34 +08:00
if err := shell(c, s, s.Command(), db); err != nil {
fmt.Fprintf(s, "error: %v\n", err)
2017-10-31 17:17:06 +08:00
}
2017-11-08 02:40:14 +08:00
case strings.HasPrefix(username, "invite:"):
return
2017-09-30 19:12:43 +08:00
default:
2017-10-31 16:24:18 +08:00
host, err := RemoteHostFromSession(s, db)
2017-09-30 19:12:43 +08:00
if err != nil {
fmt.Fprintf(s, "error: %v\n", err)
2017-10-31 16:24:18 +08:00
// FIXME: print available hosts
2017-09-30 19:12:43 +08:00
return
}
2017-11-13 17:13:17 +08:00
// load up-to-date objects
// FIXME: cache them or try not to load them
var tmpUser User
2017-12-04 01:18:17 +08:00
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", currentUser.ID).First(&tmpUser).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
2017-11-13 17:13:17 +08:00
return
}
var tmpHost Host
2017-12-04 01:18:17 +08:00
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
2017-11-13 17:13:17 +08:00
return
2017-09-30 19:12:43 +08:00
}
2017-11-13 17:13:17 +08:00
2017-12-04 01:18:17 +08:00
action, err2 := CheckACLs(tmpUser, tmpHost)
if err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
2017-11-13 17:13:17 +08:00
return
}
2017-11-24 21:29:41 +08:00
// decrypt key and password
HostDecrypt(c.String("aes-key"), host)
SSHKeyDecrypt(c.String("aes-key"), host.SSHKey)
2017-11-13 17:13:17 +08:00
switch action {
2017-12-04 01:18:17 +08:00
case ACLActionAllow:
sess := Session{
UserID: currentUser.ID,
HostID: host.ID,
Status: SessionStatusActive,
}
2017-12-04 01:18:17 +08:00
if err2 := db.Create(&sess).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
sessUpdate := Session{}
2017-12-04 01:18:17 +08:00
if err2 := proxy(s, host, DynamicHostKey(db, host)); err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
sessUpdate.ErrMsg = fmt.Sprintf("%v", err2)
switch sessUpdate.ErrMsg {
case "lch closed the connection", "rch closed the connection":
sessUpdate.ErrMsg = ""
}
2017-11-13 17:13:17 +08:00
}
sessUpdate.Status = SessionStatusClosed
now := time.Now()
sessUpdate.StoppedAt = &now
db.Model(&sess).Updates(&sessUpdate)
2017-12-04 01:18:17 +08:00
case ACLActionDeny:
2017-11-13 17:13:17 +08:00
fmt.Fprintf(s, "You don't have permission to that host.\n")
default:
2017-12-04 01:18:17 +08:00
fmt.Fprintf(s, "error: invalid ACL action: %q\n", action)
2017-11-13 17:13:17 +08:00
}
2017-09-30 19:12:43 +08:00
}
})
opts := []ssh.Option{}
2018-01-01 17:41:21 +08:00
opts = append(opts, ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
ctx.SetValue(userContextKey, User{})
return ctx.User() == "healthcheck"
}))
opts = append(opts, ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
var (
2017-11-08 02:40:14 +08:00
userKey UserKey
user User
username = ctx.User()
)
// lookup user by key
2017-11-29 17:28:33 +08:00
db.Where("authorized_key = ?", string(gossh.MarshalAuthorizedKey(key))).First(&userKey)
if userKey.UserID > 0 {
2017-11-23 23:22:23 +08:00
db.Preload("Roles").Where("id = ?", userKey.UserID).First(&user)
2017-11-08 02:40:14 +08:00
if strings.HasPrefix(username, "invite:") {
2017-12-21 22:35:26 +08:00
ctx.SetValue(errorContextKey, fmt.Errorf("invites are only supported for new SSH keys; your ssh key is already associated with the user %q", user.Email))
2017-11-08 02:40:14 +08:00
}
ctx.SetValue(userContextKey, user)
return true
}
2017-11-08 02:40:14 +08:00
// handle invite "links"
if strings.HasPrefix(username, "invite:") {
inputToken := strings.Split(username, ":")[1]
if len(inputToken) > 0 {
2017-11-08 02:40:14 +08:00
db.Where("invite_token = ?", inputToken).First(&user)
}
2017-11-08 02:40:14 +08:00
if user.ID > 0 {
userKey = UserKey{
2017-11-29 17:28:33 +08:00
UserID: user.ID,
Key: key.Marshal(),
Comment: "created by sshportal",
AuthorizedKey: string(gossh.MarshalAuthorizedKey(key)),
2017-11-08 02:40:14 +08:00
}
db.Create(&userKey)
// token is only usable once
user.InviteToken = ""
2017-11-29 17:28:33 +08:00
db.Model(&user).Updates(&user)
2017-11-08 02:40:14 +08:00
ctx.SetValue(messageContextKey, fmt.Sprintf("Welcome %s!\n\nYour key is now associated with the user %q.\n", user.Name, user.Email))
ctx.SetValue(userContextKey, user)
} else {
ctx.SetValue(userContextKey, User{Name: "Anonymous"})
ctx.SetValue(errorContextKey, errors.New("your token is invalid or expired"))
}
return true
}
2017-11-08 02:40:14 +08:00
// fallback
ctx.SetValue(errorContextKey, errors.New("unknown ssh key"))
ctx.SetValue(userContextKey, User{Name: "Anonymous"})
return true
}))
2017-11-04 05:54:16 +08:00
opts = append(opts, func(srv *ssh.Server) error {
2017-11-23 16:58:32 +08:00
var key SSHKey
if err := SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
2017-11-04 05:54:16 +08:00
return err
}
2017-11-29 17:28:33 +08:00
SSHKeyDecrypt(c.String("aes-key"), &key)
2017-11-04 05:54:16 +08:00
signer, err := gossh.ParsePrivateKey([]byte(key.PrivKey))
if err != nil {
return err
}
srv.AddHostKey(signer)
return nil
})
2017-12-04 18:13:36 +08:00
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
2017-09-30 19:12:43 +08:00
return ssh.ListenAndServe(c.String("bind-address"), nil, opts...)
}
2018-01-01 17:41:21 +08:00
// healthcheck performs a healthcheck test without requiring an ssh client or
// an ssh key (used for Docker's HEALTHCHECK).
//
// It is the action of the "healthcheck" command. It connects to --addr as the
// "healthcheck" user with password authentication, and healthcheckOnce does
// the actual check.
//
// With --wait, it retries every second until a check succeeds, logging each
// failure unless --quiet is set. Otherwise it runs a single check and returns
// its error. With --quiet, the error text is suppressed and only exit code 1
// is reported.
func healthcheck(c *cli.Context) error {
	addr := c.String("addr")
	quiet := c.Bool("quiet")
	// The server's host key is accepted unconditionally: the probe only
	// checks liveness. Timeout bounds the TCP connect so that an address that
	// never answers cannot hang the check (or the --wait loop) indefinitely.
	config := gossh.ClientConfig{
		User:            "healthcheck",
		HostKeyCallback: func(hostname string, remote net.Addr, key gossh.PublicKey) error { return nil },
		Auth:            []gossh.AuthMethod{gossh.Password("healthcheck")},
		Timeout:         10 * time.Second,
	}

	if !c.Bool("wait") {
		if err := healthcheckOnce(addr, config, quiet); err != nil {
			if quiet {
				return cli.NewExitError("", 1)
			}
			return err
		}
		return nil
	}

	for {
		err := healthcheckOnce(addr, config, quiet)
		if err == nil {
			return nil
		}
		if !quiet {
			log.Printf("error: %v", err)
		}
		time.Sleep(time.Second)
	}
}
func healthcheckOnce(addr string, config gossh.ClientConfig, quiet bool) error {
client, err := gossh.Dial("tcp", addr, &config)
2018-01-01 17:41:21 +08:00
if err != nil {
return err
}
session, err := client.NewSession()
if err != nil {
return err
}
defer func() {
if err := session.Close(); err != nil {
if !quiet {
log.Printf("failed to close session: %v", err)
}
2018-01-01 17:41:21 +08:00
}
}()
var b bytes.Buffer
session.Stdout = &b
if err := session.Run(""); err != nil {
return err
}
stdout := strings.TrimSpace(b.String())
if stdout != "OK" {
return fmt.Errorf("invalid stdout: %q expected 'OK'", stdout)
}
return nil
}