Refactor bastion handler to forward every requests properly

This commit is contained in:
Manfred Touron 2017-12-31 10:38:33 +01:00
parent 072464928b
commit d6bb5e44a1
8 changed files with 443 additions and 694 deletions

View file

@ -4,6 +4,7 @@
Breaking changes:
* Use `sshportal server` instead of `sshportal` to start a new server (nothing to change if using the docker image)
* Remove `--config-user` and `--healthcheck-user` global options
Changes:
* Fix connection failure when sending too many environment variables (fix [#22](https://github.com/moul/sshportal/issues/22))

51
db.go
View file

@ -10,8 +10,8 @@ import (
"time"
"github.com/asaskevich/govalidator"
"github.com/gliderlabs/ssh"
"github.com/jinzhu/gorm"
gossh "golang.org/x/crypto/ssh"
)
type Config struct {
@ -168,16 +168,6 @@ func init() {
}))
}
func RemoteHostFromSession(s ssh.Session, db *gorm.DB) (*Host, error) {
var host Host
db.Preload("SSHKey").Where("name = ?", s.User()).Find(&host)
if host.Name == "" {
// FIXME: add available hosts
return nil, fmt.Errorf("No such target: %q", s.User())
}
return &host, nil
}
func (host *Host) URL() string {
return fmt.Sprintf("%s@%s", host.User, host.Addr)
}
@ -215,6 +205,37 @@ func HostsPreload(db *gorm.DB) *gorm.DB {
func HostsByIdentifiers(db *gorm.DB, identifiers []string) *gorm.DB {
return db.Where("id IN (?)", identifiers).Or("name IN (?)", identifiers)
}
func HostByName(db *gorm.DB, name string) (*Host, error) {
var host Host
db.Preload("SSHKey").Where("name = ?", name).Find(&host)
if host.Name == "" {
// FIXME: add available hosts
return nil, fmt.Errorf("No such target: %q", name)
}
return &host, nil
}
func (host *Host) clientConfig(hk gossh.HostKeyCallback) (*gossh.ClientConfig, error) {
config := gossh.ClientConfig{
User: host.User,
HostKeyCallback: hk,
Auth: []gossh.AuthMethod{},
}
if host.SSHKey != nil {
signer, err := gossh.ParsePrivateKey([]byte(host.SSHKey.PrivKey))
if err != nil {
return nil, err
}
config.Auth = append(config.Auth, gossh.PublicKeys(signer))
}
if host.Password != "" {
config.Auth = append(config.Auth, gossh.Password(host.Password))
}
if len(config.Auth) == 0 {
return nil, fmt.Errorf("no valid authentication method for host %q", host.Name)
}
return &config, nil
}
// SSHKey helpers
@ -251,18 +272,18 @@ func UsersPreload(db *gorm.DB) *gorm.DB {
func UsersByIdentifiers(db *gorm.DB, identifiers []string) *gorm.DB {
return db.Where("id IN (?)", identifiers).Or("email IN (?)", identifiers).Or("name IN (?)", identifiers)
}
func UserHasRole(user User, name string) bool {
for _, role := range user.Roles {
func (u *User) HasRole(name string) bool {
for _, role := range u.Roles {
if role.Name == name {
return true
}
}
return false
}
func UserCheckRoles(user User, names []string) error {
func (u *User) CheckRoles(names []string) error {
ok := false
for _, name := range names {
if UserHasRole(user, name) {
if u.HasRole(name) {
ok = true
break
}

206
main.go
View file

@ -2,7 +2,6 @@ package main
import (
"bytes"
"errors"
"fmt"
"log"
"math/rand"
@ -18,8 +17,6 @@ import (
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
"github.com/moul/sshportal/pkg/bastionsession"
)
var (
@ -33,14 +30,6 @@ var (
GitBranch string
)
type sshportalContextKey string
var (
userContextKey = sshportalContextKey("user")
messageContextKey = sshportalContextKey("message")
errorContextKey = sshportalContextKey("error")
)
func main() {
rand.Seed(time.Now().UnixNano())
@ -75,16 +64,6 @@ func main() {
Name: "debug, D",
Usage: "Display debug information",
},
cli.StringFlag{
Name: "config-user",
Usage: "SSH user that spawns a configuration shell",
Value: "admin",
},
cli.StringFlag{
Name: "healthcheck-user",
Usage: "SSH user that returns healthcheck status without checking the SSH key",
Value: "healthcheck",
},
cli.StringFlag{
Name: "aes-key",
Usage: "Encrypt sensitive data in database (length: 16, 24 or 32)",
@ -141,160 +120,12 @@ func server(c *cli.Context) error {
return err
}
// ssh server
shellHandler := func(s ssh.Session) {
currentUser := s.Context().Value(userContextKey).(User)
if s.User() != "healthcheck" {
log.Printf("New connection(shell): sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
}
if err2 := s.Context().Value(errorContextKey); err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
if msg := s.Context().Value(messageContextKey); msg != nil {
fmt.Fprint(s, msg.(string))
}
switch username := s.User(); {
case username == c.String("healthcheck-user"):
fmt.Fprintln(s, "OK")
return
case username == currentUser.Name || username == currentUser.Email || username == c.String("config-user"):
if err = shell(c, s, s.Command(), db); err != nil {
fmt.Fprintf(s, "error: %v\n", err)
}
case strings.HasPrefix(username, "invite:"):
return
default:
fmt.Fprintf(s, "error: invalid user\n")
}
}
bastionHandler := func(s ssh.Session) {
currentUser := s.Context().Value(userContextKey).(User)
log.Printf("New connection(bastion): sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
var host *Host
host, err = RemoteHostFromSession(s, db)
if err != nil {
fmt.Fprintf(s, "error: %v\n", err)
// FIXME: print available hosts
return
}
// load up-to-date objects
// FIXME: cache them or try not to load them
var tmpUser User
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", currentUser.ID).First(&tmpUser).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
var tmpHost Host
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
action, err2 := CheckACLs(tmpUser, tmpHost)
if err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
// decrypt key and password
HostDecrypt(c.String("aes-key"), host)
SSHKeyDecrypt(c.String("aes-key"), host.SSHKey)
switch action {
case ACLActionAllow:
sess := Session{
UserID: currentUser.ID,
HostID: host.ID,
Status: SessionStatusActive,
}
if err2 := db.Create(&sess).Error; err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
return
}
sessUpdate := Session{}
if err2 := proxy(s, host, DynamicHostKey(db, host)); err2 != nil {
fmt.Fprintf(s, "error: %v\n", err2)
sessUpdate.ErrMsg = fmt.Sprintf("%v", err2)
switch sessUpdate.ErrMsg {
case "lch closed the connection", "rch closed the connection":
sessUpdate.ErrMsg = ""
}
}
sessUpdate.Status = SessionStatusClosed
now := time.Now()
sessUpdate.StoppedAt = &now
db.Model(&sess).Updates(&sessUpdate)
case ACLActionDeny:
fmt.Fprintf(s, "You don't have permission to that host.\n")
default:
fmt.Fprintf(s, "error: invalid ACL action: %q\n", action)
}
}
opts := []ssh.Option{}
opts = append(opts, ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
ctx.SetValue(userContextKey, User{})
return ctx.User() == "healthcheck"
}))
opts = append(opts, ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
var (
userKey UserKey
user User
username = ctx.User()
)
// lookup user by key
db.Where("authorized_key = ?", string(gossh.MarshalAuthorizedKey(key))).First(&userKey)
if userKey.UserID > 0 {
db.Preload("Roles").Where("id = ?", userKey.UserID).First(&user)
if strings.HasPrefix(username, "invite:") {
ctx.SetValue(errorContextKey, fmt.Errorf("invites are only supported for new SSH keys; your ssh key is already associated with the user %q", user.Email))
}
ctx.SetValue(userContextKey, user)
return true
}
// handle invite "links"
if strings.HasPrefix(username, "invite:") {
inputToken := strings.Split(username, ":")[1]
if len(inputToken) > 0 {
db.Where("invite_token = ?", inputToken).First(&user)
}
if user.ID > 0 {
userKey = UserKey{
UserID: user.ID,
Key: key.Marshal(),
Comment: "created by sshportal",
AuthorizedKey: string(gossh.MarshalAuthorizedKey(key)),
}
db.Create(&userKey)
// token is only usable once
user.InviteToken = ""
db.Model(&user).Updates(&user)
ctx.SetValue(messageContextKey, fmt.Sprintf("Welcome %s!\n\nYour key is now associated with the user %q.\n", user.Name, user.Email))
ctx.SetValue(userContextKey, user)
} else {
ctx.SetValue(userContextKey, User{Name: "Anonymous"})
ctx.SetValue(errorContextKey, errors.New("your token is invalid or expired"))
}
return true
}
// fallback
ctx.SetValue(errorContextKey, errors.New("unknown ssh key"))
ctx.SetValue(userContextKey, User{Name: "Anonymous"})
return true
}))
// custom PublicKeyAuth handler
opts = append(opts, ssh.PublicKeyAuth(publicKeyAuthHandler(db, c)))
opts = append(opts, ssh.PasswordAuth(passwordAuthHandler(db, c)))
// retrieve sshportal SSH private key from databse
opts = append(opts, func(srv *ssh.Server) error {
var key SSHKey
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
@ -311,35 +142,26 @@ func server(c *cli.Context) error {
return nil
})
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
// create TCP listening socket
ln, err := net.Listen("tcp", c.String("bind-address"))
if err != nil {
return err
}
srv := &ssh.Server{Addr: c.String("bind-address"), Handler: shellHandler}
// configure server
srv := &ssh.Server{
Addr: c.String("bind-address"),
Handler: shellHandler, // ssh.Server.Handler is the handler for the DefaultSessionHandler
Version: fmt.Sprintf("sshportal-%s", Version),
ChannelHandler: channelHandler,
}
for _, opt := range opts {
if err := srv.SetOption(opt); err != nil {
return err
}
}
srv.Version = fmt.Sprintf("sshportal-%s", Version)
srv.ChannelHandler = func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
if newChan.ChannelType() != "session" {
if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
log.Printf("error: failed to reject channel: %v", err)
}
return
}
// TODO: handle direct-tcp
currentUser := ctx.Value(userContextKey).(User)
username := conn.User()
if username == c.String("healthcheck-user") || username == currentUser.Name || username == currentUser.Email || username == c.String("config-user") || strings.HasPrefix(username, "invite:") {
ssh.DefaultChannelHandler(srv, conn, newChan, ctx)
} else {
bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionHandler)
}
}
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
return srv.Serve(ln)
}

View file

@ -1,256 +1,89 @@
package bastionsession
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"sync"
"io"
"github.com/anmitsu/go-shlex"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
// maxSigBufSize is how many signals will be buffered
// when there is no signal channel specified
const maxSigBufSize = 128
type Config struct {
Addr string
ClientConfig *gossh.ClientConfig
}
func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, handler ssh.Handler) {
func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, config Config) error {
if newChan.ChannelType() != "session" {
newChan.Reject(gossh.UnknownChannelType, "unsupported channel type")
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
if handler == nil {
handler = srv.Handler
}
sess := &session{
Channel: ch,
conn: conn,
handler: handler,
ptyCb: srv.PtyCallback,
maskedReqs: make(chan *gossh.Request, 5),
ctx: ctx,
}
sess.ctx.SetValue("masked-reqs", sess.maskedReqs)
//ssh.DefaultChannelHandler(srv, conn, ch, ctx)
//return
sess.handleRequests(reqs)
}
type session struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
handler ssh.Handler
handled bool
exited bool
pty *ssh.Pty
winch chan ssh.Window
env []string
ptyCb ssh.PtyCallback
cmd []string
ctx ssh.Context
sigCh chan<- ssh.Signal
sigBuf []ssh.Signal
maskedReqs chan *gossh.Request
}
func (sess *session) Write(p []byte) (n int, err error) {
if sess.pty != nil {
m := len(p)
// normalize \n to \r\n when pty is accepted.
// this is a hardcoded shortcut since we don't support terminal modes.
p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1)
p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1)
n, err = sess.Channel.Write(p)
if n > m {
n = m
}
return
}
return sess.Channel.Write(p)
}
func (sess *session) PublicKey() ssh.PublicKey {
sessionkey := sess.ctx.Value(ssh.ContextKeyPublicKey)
if sessionkey == nil {
return nil
}
return sessionkey.(ssh.PublicKey)
}
func (sess *session) Permissions() ssh.Permissions {
// use context permissions because its properly
// wrapped and easier to dereference
perms := sess.ctx.Value(ssh.ContextKeyPermissions).(*ssh.Permissions)
return *perms
}
func (sess *session) Context() context.Context {
return sess.ctx
}
func (sess *session) Exit(code int) error {
sess.Lock()
defer sess.Unlock()
if sess.exited {
return errors.New("Session.Exit called multiple times")
lch, lreqs, err := newChan.Accept()
// TODO: defer clean closer
if err != nil {
// TODO: trigger event callback
return nil
}
sess.exited = true
status := struct{ Status uint32 }{uint32(code)}
_, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status))
// open client channel
rconn, err := gossh.Dial("tcp", config.Addr, config.ClientConfig)
if err != nil {
return err
}
defer func() { _ = rconn.Close() }()
rch, rreqs, err := rconn.OpenChannel("session", []byte{})
if err != nil {
return err
}
close(sess.maskedReqs)
return sess.Close()
// pipe everything
return pipe(lreqs, rreqs, lch, rch)
}
func (sess *session) User() string {
return sess.conn.User()
}
func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel) error {
defer func() {
_ = lch.Close()
_ = rch.Close()
}()
func (sess *session) RemoteAddr() net.Addr {
return sess.conn.RemoteAddr()
}
errch := make(chan error, 1)
func (sess *session) LocalAddr() net.Addr {
return sess.conn.LocalAddr()
}
go func() {
_, _ = io.Copy(lch, rch)
errch <- errors.New("lch closed the connection")
}()
func (sess *session) Environ() []string {
return append([]string(nil), sess.env...)
}
go func() {
_, _ = io.Copy(rch, lch)
errch <- errors.New("rch closed the connection")
}()
func (sess *session) Command() []string {
return append([]string(nil), sess.cmd...)
}
func (sess *session) Pty() (ssh.Pty, <-chan ssh.Window, bool) {
if sess.pty != nil {
return *sess.pty, sess.winch, true
}
return ssh.Pty{}, sess.winch, false
}
func (sess *session) MaskedReqs() chan *gossh.Request {
return sess.maskedReqs
}
func (sess *session) Signals(c chan<- ssh.Signal) {
sess.Lock()
defer sess.Unlock()
sess.sigCh = c
if len(sess.sigBuf) > 0 {
go func() {
for _, sig := range sess.sigBuf {
sess.sigCh <- sig
for {
select {
case req := <-lreqs: // forward ssh requests from local to remote
if req == nil {
return nil
}
}()
}
}
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
for req := range reqs {
addToMaskedReqs := true
switch req.Type {
case "shell", "exec":
if sess.handled {
req.Reply(false, nil)
addToMaskedReqs = false
continue
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return err
}
sess.handled = true
// req.Reply(true, nil) // let the proxy reply
var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
sess.cmd, _ = shlex.Split(payload.Value, true)
go func() {
sess.handler(sess)
sess.Exit(0)
}()
case "env":
if sess.handled {
req.Reply(false, nil)
continue
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
var kv struct{ Key, Value string }
gossh.Unmarshal(req.Payload, &kv)
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
req.Reply(true, nil)
case "signal":
var payload struct{ Signal string }
gossh.Unmarshal(req.Payload, &payload)
sess.Lock()
if sess.sigCh != nil {
sess.sigCh <- ssh.Signal(payload.Signal)
} else {
if len(sess.sigBuf) < maxSigBufSize {
sess.sigBuf = append(sess.sigBuf, ssh.Signal(payload.Signal))
}
case req := <-rreqs: // forward ssh requests from remote to local
if req == nil {
return nil
}
sess.Unlock()
case "pty-req":
if sess.handled || sess.pty != nil {
req.Reply(false, nil)
addToMaskedReqs = false
continue
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return err
}
ptyReq, ok := parsePtyRequest(req.Payload)
if !ok {
req.Reply(false, nil)
addToMaskedReqs = false
continue
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
if sess.ptyCb != nil {
ok := sess.ptyCb(sess.ctx, ptyReq)
if !ok {
req.Reply(false, nil)
addToMaskedReqs = false
continue
}
}
sess.pty = &ptyReq
sess.winch = make(chan ssh.Window, 1)
sess.winch <- ptyReq.Window
defer func() {
// when reqs is closed
close(sess.winch)
}()
//req.Reply(ok, nil) // let the proxy reply
case "window-change":
if sess.pty == nil {
req.Reply(false, nil)
continue
}
win, ok := parseWinchRequest(req.Payload)
if ok {
sess.pty.Window = win
sess.winch <- win
}
req.Reply(ok, nil)
case "auth-agent-req@openssh.com":
// TODO: option/callback to allow agent forwarding
ssh.SetAgentRequested(sess.ctx)
req.Reply(true, nil)
default:
// TODO: debug log
}
if addToMaskedReqs {
sess.maskedReqs <- req
case err := <-errch:
return err
}
}
}

View file

@ -1,79 +0,0 @@
package bastionsession
import (
"encoding/binary"
"github.com/gliderlabs/ssh"
)
func parsePtyRequest(s []byte) (pty ssh.Pty, ok bool) {
term, s, ok := parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if width32 < 1 {
ok = false
}
if !ok {
return
}
height32, _, ok := parseUint32(s)
if height32 < 1 {
ok = false
}
if !ok {
return
}
pty = ssh.Pty{
Term: term,
Window: ssh.Window{
Width: int(width32),
Height: int(height32),
},
}
return
}
func parseWinchRequest(s []byte) (win ssh.Window, ok bool) {
width32, s, ok := parseUint32(s)
if width32 < 1 {
ok = false
}
if !ok {
return
}
height32, _, ok := parseUint32(s)
if height32 < 1 {
ok = false
}
if !ok {
return
}
win = ssh.Window{
Width: int(width32),
Height: int(height32),
}
return
}
func parseString(in []byte) (out string, rest []byte, ok bool) {
if len(in) < 4 {
return
}
length := binary.BigEndian.Uint32(in)
if uint32(len(in)) < 4+length {
return
}
out = string(in[4 : 4+length])
rest = in[4+length:]
ok = true
return
}
func parseUint32(in []byte) (uint32, []byte, bool) {
if len(in) < 4 {
return 0, nil, false
}
return binary.BigEndian.Uint32(in), in[4:], true
}

106
proxy.go
View file

@ -1,106 +0,0 @@
package main
import (
"errors"
"fmt"
"io"
"log"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
func proxy(s ssh.Session, host *Host, hk gossh.HostKeyCallback) error {
config, err := host.clientConfig(s, hk)
if err != nil {
return err
}
rconn, err := gossh.Dial("tcp", host.Addr, config)
if err != nil {
return err
}
defer func() { _ = rconn.Close() }()
rch, rreqs, err := rconn.OpenChannel("session", []byte{})
if err != nil {
return err
}
log.Println("SSH Connection established")
maskedReqs := s.Context().Value("masked-reqs")
if maskedReqs == nil {
return fmt.Errorf("ctx.maskedReqs doesn't exist")
}
return pipe(maskedReqs.(chan *gossh.Request), rreqs, s, rch)
}
func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel) error {
defer func() {
_ = lch.Close()
_ = rch.Close()
}()
errch := make(chan error, 1)
go func() {
_, _ = io.Copy(lch, rch)
errch <- errors.New("lch closed the connection")
}()
go func() {
_, _ = io.Copy(rch, lch)
errch <- errors.New("rch closed the connection")
}()
for {
select {
case req := <-lreqs: // forward ssh requests from local to remote
if req == nil {
return nil
}
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return err
}
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
case req := <-rreqs: // forward ssh requests from remote to local
if req == nil {
return nil
}
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return err
}
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
case err := <-errch:
return err
}
}
}
func (host *Host) clientConfig(_ ssh.Session, hk gossh.HostKeyCallback) (*gossh.ClientConfig, error) {
config := gossh.ClientConfig{
User: host.User,
HostKeyCallback: hk,
Auth: []gossh.AuthMethod{},
}
if host.SSHKey != nil {
signer, err := gossh.ParsePrivateKey([]byte(host.SSHKey.PrivKey))
if err != nil {
return nil, err
}
config.Auth = append(config.Auth, gossh.PublicKeys(signer))
}
if host.Password != "" {
config.Auth = append(config.Auth, gossh.Password(host.Password))
}
if len(config.Auth) == 0 {
return nil, fmt.Errorf("no valid authentication method for host %q", host.Name)
}
return &config, nil
}

137
shell.go
View file

@ -14,7 +14,6 @@ import (
"github.com/asaskevich/govalidator"
humanize "github.com/dustin/go-humanize"
"github.com/gliderlabs/ssh"
"github.com/jinzhu/gorm"
"github.com/mgutz/ansi"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/olekukonko/tablewriter"
@ -33,7 +32,11 @@ var banner = `
`
var startTime = time.Now()
func shell(globalContext *cli.Context, s ssh.Session, sshCommand []string, db *gorm.DB) error {
func shell(s ssh.Session) error {
var (
sshCommand = s.Command()
actx = s.Context().Value(authContextKey).(*authContext)
)
if len(sshCommand) == 0 {
if _, err := fmt.Fprint(s, banner); err != nil {
return err
@ -55,7 +58,11 @@ GLOBAL OPTIONS:
app.Writer = s
app.HideVersion = true
myself := s.Context().Value(userContextKey).(User)
var (
myself = &actx.user
db = actx.db
)
app.Commands = []cli.Command{
{
Name: "acl",
@ -74,7 +81,7 @@ GLOBAL OPTIONS:
cli.UintFlag{Name: "weight, w", Usage: "Assigns the ACL weight (priority)"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
acl := ACL{
@ -124,7 +131,7 @@ GLOBAL OPTIONS:
if c.NArg() < 1 {
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -145,7 +152,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -206,7 +213,7 @@ GLOBAL OPTIONS:
if c.NArg() < 1 {
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -230,7 +237,7 @@ GLOBAL OPTIONS:
if c.NArg() < 1 {
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -303,7 +310,7 @@ GLOBAL OPTIONS:
},
Description: "ssh admin@portal config backup > sshportal.bkp",
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -316,11 +323,11 @@ GLOBAL OPTIONS:
return err
}
for _, key := range config.SSHKeys {
SSHKeyDecrypt(globalContext.String("aes-key"), key)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
}
if !c.Bool("decrypt") {
for _, key := range config.SSHKeys {
if err := SSHKeyEncrypt(globalContext.String("aes-key"), key); err != nil {
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err != nil {
return err
}
}
@ -330,11 +337,11 @@ GLOBAL OPTIONS:
return err
}
for _, host := range config.Hosts {
HostDecrypt(globalContext.String("aes-key"), host)
HostDecrypt(actx.globalContext.String("aes-key"), host)
}
if !c.Bool("decrypt") {
for _, host := range config.Hosts {
if err := HostEncrypt(globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
return err
}
}
@ -382,7 +389,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "decrypt", Usage: "do not encrypt sensitive data"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -449,9 +456,9 @@ GLOBAL OPTIONS:
}
}
for _, host := range config.Hosts {
HostDecrypt(globalContext.String("aes-key"), host)
HostDecrypt(actx.globalContext.String("aes-key"), host)
if !c.Bool("decrypt") {
if err := HostEncrypt(globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
return err
}
}
@ -485,9 +492,9 @@ GLOBAL OPTIONS:
}
}
for _, sshKey := range config.SSHKeys {
SSHKeyDecrypt(globalContext.String("aes-key"), sshKey)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), sshKey)
if !c.Bool("decrypt") {
if err := SSHKeyEncrypt(globalContext.String("aes-key"), sshKey); err != nil {
if err := SSHKeyEncrypt(actx.globalContext.String("aes-key"), sshKey); err != nil {
return err
}
}
@ -543,7 +550,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -572,7 +579,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -642,7 +649,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -687,7 +694,7 @@ GLOBAL OPTIONS:
}
// encrypt
if err := HostEncrypt(globalContext.String("aes-key"), host); err != nil {
if err := HostEncrypt(actx.globalContext.String("aes-key"), host); err != nil {
return err
}
@ -709,13 +716,13 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin", "listhosts"}); err != nil {
if err := myself.CheckRoles([]string{"admin", "listhosts"}); err != nil {
return err
}
var hosts []*Host
db = db.Preload("Groups")
if UserHasRole(myself, "admin") {
if myself.HasRole("admin") {
db = db.Preload("SSHKey")
}
if err := HostsByIdentifiers(db, c.Args()).Find(&hosts).Error; err != nil {
@ -724,7 +731,7 @@ GLOBAL OPTIONS:
if c.Bool("decrypt") {
for _, host := range hosts {
HostDecrypt(globalContext.String("aes-key"), host)
HostDecrypt(actx.globalContext.String("aes-key"), host)
}
}
@ -740,7 +747,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin", "listhosts"}); err != nil {
if err := myself.CheckRoles([]string{"admin", "listhosts"}); err != nil {
return err
}
@ -808,7 +815,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -831,7 +838,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -902,7 +909,7 @@ GLOBAL OPTIONS:
cli.StringFlag{Name: "comment", Usage: "Adds a comment"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -933,7 +940,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -954,7 +961,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1007,7 +1014,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1019,18 +1026,18 @@ GLOBAL OPTIONS:
Name: "info",
Usage: "Shows system-wide information",
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
fmt.Fprintf(s, "Debug mode (server): %v\n", globalContext.Bool("debug"))
fmt.Fprintf(s, "Debug mode (server): %v\n", actx.globalContext.Bool("debug"))
hostname, _ := os.Hostname()
fmt.Fprintf(s, "Hostname: %s\n", hostname)
fmt.Fprintf(s, "CPUs: %d\n", runtime.NumCPU())
fmt.Fprintf(s, "Demo mode: %v\n", globalContext.Bool("demo"))
fmt.Fprintf(s, "DB Driver: %s\n", globalContext.String("db-driver"))
fmt.Fprintf(s, "DB Conn: %s\n", globalContext.String("db-conn"))
fmt.Fprintf(s, "Bind Address: %s\n", globalContext.String("bind-address"))
fmt.Fprintf(s, "Demo mode: %v\n", actx.globalContext.Bool("demo"))
fmt.Fprintf(s, "DB Driver: %s\n", actx.globalContext.String("db-driver"))
fmt.Fprintf(s, "DB Conn: %s\n", actx.globalContext.String("db-conn"))
fmt.Fprintf(s, "Bind Address: %s\n", actx.globalContext.String("bind-address"))
fmt.Fprintf(s, "System Time: %v\n", time.Now().Format(time.RFC3339Nano))
fmt.Fprintf(s, "OS Type: %s\n", runtime.GOOS)
fmt.Fprintf(s, "OS Architecture: %s\n", runtime.GOARCH)
@ -1038,7 +1045,7 @@ GLOBAL OPTIONS:
fmt.Fprintf(s, "Go version (build): %v\n", runtime.Version())
fmt.Fprintf(s, "Uptime: %v\n", time.Since(startTime))
fmt.Fprintf(s, "User email: %v\n", myself.ID)
fmt.Fprintf(s, "User ID: %v\n", myself.ID)
fmt.Fprintf(s, "User email: %s\n", myself.Email)
fmt.Fprintf(s, "Version: %s\n", Version)
fmt.Fprintf(s, "GIT SHA: %s\n", GitSha)
@ -1066,7 +1073,7 @@ GLOBAL OPTIONS:
cli.StringFlag{Name: "comment", Usage: "Adds a comment"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1076,8 +1083,8 @@ GLOBAL OPTIONS:
}
key, err := NewSSHKey(c.String("type"), c.Uint("length"))
if globalContext.String("aes-key") != "" {
if err2 := SSHKeyEncrypt(globalContext.String("aes-key"), key); err2 != nil {
if actx.globalContext.String("aes-key") != "" {
if err2 := SSHKeyEncrypt(actx.globalContext.String("aes-key"), key); err2 != nil {
return err2
}
}
@ -1111,7 +1118,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1122,7 +1129,7 @@ GLOBAL OPTIONS:
if c.Bool("decrypt") {
for _, key := range keys {
SSHKeyDecrypt(globalContext.String("aes-key"), key)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), key)
}
}
@ -1138,7 +1145,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1192,7 +1199,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1231,7 +1238,7 @@ GLOBAL OPTIONS:
if err := SSHKeysByIdentifiers(SSHKeysPreload(db), c.Args()).First(&key).Error; err != nil {
return err
}
SSHKeyDecrypt(globalContext.String("aes-key"), &key)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), &key)
type line struct {
key string
@ -1305,7 +1312,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1333,7 +1340,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1380,7 +1387,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1441,7 +1448,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1464,7 +1471,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1544,7 +1551,7 @@ GLOBAL OPTIONS:
cli.StringFlag{Name: "comment", Usage: "Adds a comment"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1562,7 +1569,7 @@ GLOBAL OPTIONS:
// FIXME: check if name already exists
// FIXME: add myself to the new group
userGroup.Users = []*User{&myself}
userGroup.Users = []*User{myself}
if err := db.Create(&userGroup).Error; err != nil {
return err
@ -1579,7 +1586,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1600,7 +1607,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1652,7 +1659,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1677,7 +1684,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1724,7 +1731,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1745,7 +1752,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1795,7 +1802,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1816,7 +1823,7 @@ GLOBAL OPTIONS:
return cli.ShowSubcommandHelp(c)
}
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1837,7 +1844,7 @@ GLOBAL OPTIONS:
cli.BoolFlag{Name: "quiet, q", Usage: "Only display IDs"},
},
Action: func(c *cli.Context) error {
if err := UserCheckRoles(myself, []string{"admin"}); err != nil {
if err := myself.CheckRoles([]string{"admin"}); err != nil {
return err
}
@ -1924,7 +1931,7 @@ GLOBAL OPTIONS:
if len(words) == 0 {
continue
}
NewEvent("shell", words[0]).SetAuthor(&myself).SetArg("interactive", true).SetArg("args", words[1:]).Log(db)
NewEvent("shell", words[0]).SetAuthor(myself).SetArg("interactive", true).SetArg("args", words[1:]).Log(db)
if err := app.Run(append([]string{"config"}, words...)); err != nil {
if cliErr, ok := err.(*cli.ExitError); ok {
if cliErr.ExitCode() != 0 {
@ -1937,7 +1944,7 @@ GLOBAL OPTIONS:
}
}
} else { // oneshot mode
NewEvent("shell", sshCommand[0]).SetAuthor(&myself).SetArg("interactive", false).SetArg("args", sshCommand[1:]).Log(db)
NewEvent("shell", sshCommand[0]).SetAuthor(myself).SetArg("interactive", false).SetArg("args", sshCommand[1:]).Log(db)
if err := app.Run(append([]string{"config"}, sshCommand...)); err != nil {
if errMsg := err.Error(); errMsg != "" {
fmt.Fprintf(s, "error: %s\n", errMsg)

282
ssh.go
View file

@ -2,35 +2,285 @@ package main
import (
"bytes"
"errors"
"fmt"
"log"
"net"
"strings"
"time"
"github.com/gliderlabs/ssh"
"github.com/jinzhu/gorm"
"github.com/moul/sshportal/pkg/bastionsession"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
)
type dynamicHostKey struct {
db *gorm.DB
host *Host
type sshportalContextKey string
var authContextKey = sshportalContextKey("auth")
type authContext struct {
message string
err error
user User
inputUsername string
db *gorm.DB
userKey UserKey
globalContext *cli.Context
authMethod string
authSuccess bool
}
func (d *dynamicHostKey) check(hostname string, remote net.Addr, key gossh.PublicKey) error {
if len(d.host.HostKey) == 0 {
log.Println("Discovering host fingerprint...")
return d.db.Model(d.host).Update("HostKey", key.Marshal()).Error
type UserType string
const (
UserTypeHealthcheck UserType = "healthcheck"
UserTypeBastion = "bastion"
UserTypeInvite = "invite"
UserTypeShell = "shell"
)
type SessionType string
const (
SessionTypeBastion SessionType = "bastion"
SessionTypeShell = "shell"
)
func (c authContext) userType() UserType {
switch {
case c.inputUsername == "healthcheck":
return UserTypeHealthcheck
case c.inputUsername == c.user.Name || c.inputUsername == c.user.Email || c.inputUsername == "admin":
return UserTypeShell
case strings.HasPrefix(c.inputUsername, "invite:"):
return UserTypeInvite
default:
return UserTypeBastion
}
}
func (c authContext) sessionType() SessionType {
switch c.userType() {
case "bastion":
return SessionTypeBastion
default:
return SessionTypeShell
}
}
func dynamicHostKey(db *gorm.DB, host *Host) gossh.HostKeyCallback {
return func(hostname string, remote net.Addr, key gossh.PublicKey) error {
if len(host.HostKey) == 0 {
log.Println("Discovering host fingerprint...")
return db.Model(host).Update("HostKey", key.Marshal()).Error
}
if !bytes.Equal(host.HostKey, key.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
}
return nil
}
}
func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
switch newChan.ChannelType() {
case "session":
default:
// TODO: handle direct-tcp
if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
log.Printf("error: failed to reject channel: %v", err)
}
return
}
if !bytes.Equal(d.host.HostKey, key.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
actx := ctx.Value(authContextKey).(*authContext)
switch actx.userType() {
case UserTypeBastion:
log.Printf("New connection(bastion): sshUser=%q remote=%q local=%q dbUser=id:%q,email:%s", conn.User(), conn.RemoteAddr(), conn.LocalAddr(), actx.user.ID, actx.user.Email)
host, clientConfig, err := bastionConfig(ctx)
if err != nil {
ch, _, err2 := newChan.Accept()
if err2 != nil {
return
}
fmt.Fprintf(ch, "error: %v\n", err)
// FIXME: force close all channels
return
}
sess := Session{
UserID: actx.user.ID,
HostID: host.ID,
Status: SessionStatusActive,
}
if err = actx.db.Create(&sess).Error; err != nil {
ch, _, err2 := newChan.Accept()
if err2 != nil {
return
}
fmt.Fprintf(ch, "error: %v\n", err)
return
}
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
Addr: host.Addr,
ClientConfig: clientConfig,
})
now := time.Now()
sessUpdate := Session{
Status: SessionStatusClosed,
ErrMsg: fmt.Sprintf("%v", err),
StoppedAt: &now,
}
switch sessUpdate.ErrMsg {
case "lch closed the connection", "rch closed the connection":
sessUpdate.ErrMsg = ""
}
actx.db.Model(&sess).Updates(&sessUpdate)
default: // shell
ssh.DefaultChannelHandler(srv, conn, newChan, ctx)
}
return nil
}
// DynamicHostKey returns a function for use in
// ClientConfig.HostKeyCallback to dynamically learn or accept host key.
func DynamicHostKey(db *gorm.DB, host *Host) gossh.HostKeyCallback {
// FIXME: forward interactively the host key checking
hk := &dynamicHostKey{db, host}
return hk.check
func bastionConfig(ctx ssh.Context) (*Host, *gossh.ClientConfig, error) {
actx := ctx.Value(authContextKey).(*authContext)
host, err := HostByName(actx.db, actx.inputUsername)
if err != nil {
return nil, nil, err
}
clientConfig, err := host.clientConfig(dynamicHostKey(actx.db, host))
if err != nil {
return nil, nil, err
}
var tmpUser User
if err = actx.db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", actx.user.ID).First(&tmpUser).Error; err != nil {
return nil, nil, err
}
var tmpHost Host
if err = actx.db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err != nil {
return nil, nil, err
}
action, err2 := CheckACLs(tmpUser, tmpHost)
if err2 != nil {
return nil, nil, err2
}
HostDecrypt(actx.globalContext.String("aes-key"), host)
SSHKeyDecrypt(actx.globalContext.String("aes-key"), host.SSHKey)
switch action {
case ACLActionAllow:
case ACLActionDeny:
return nil, nil, fmt.Errorf("you don't have permission to that host")
default:
return nil, nil, fmt.Errorf("invalid ACL action: %q", action)
}
return host, clientConfig, nil
}
func shellHandler(s ssh.Session) {
actx := s.Context().Value(authContextKey).(*authContext)
if actx.userType() != UserTypeHealthcheck {
log.Printf("New connection(shell): sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), actx.user.ID, actx.user.Email)
}
if actx.err != nil {
fmt.Fprintf(s, "error: %v\n", actx.err)
return
}
if actx.message != "" {
fmt.Fprint(s, actx.message)
}
switch actx.userType() {
case UserTypeHealthcheck:
fmt.Fprintln(s, "OK")
return
case UserTypeShell:
if err := shell(s); err != nil {
fmt.Fprintf(s, "error: %v\n", err)
}
return
case UserTypeInvite:
// do nothing (message was printed at the beginning of the function)
return
}
panic("should not happen")
}
func passwordAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PasswordHandler {
return func(ctx ssh.Context, pass string) bool {
actx := &authContext{
db: db,
inputUsername: ctx.User(),
globalContext: globalContext,
authMethod: "password",
}
actx.authSuccess = actx.userType() == UserTypeHealthcheck
ctx.SetValue(authContextKey, actx)
return actx.authSuccess
}
}
func publicKeyAuthHandler(db *gorm.DB, globalContext *cli.Context) ssh.PublicKeyHandler {
return func(ctx ssh.Context, key ssh.PublicKey) bool {
actx := &authContext{
db: db,
inputUsername: ctx.User(),
globalContext: globalContext,
authMethod: "pubkey",
authSuccess: true,
}
ctx.SetValue(authContextKey, actx)
// lookup user by key
db.Where("authorized_key = ?", string(gossh.MarshalAuthorizedKey(key))).First(&actx.userKey)
if actx.userKey.UserID > 0 {
db.Preload("Roles").Where("id = ?", actx.userKey.UserID).First(&actx.user)
if actx.userType() == "invite" {
actx.err = fmt.Errorf("invites are only supported for new SSH keys; your ssh key is already associated with the user %q", actx.user.Email)
}
return true
}
// handle invite "links"
if actx.userType() == "invite" {
inputToken := strings.Split(actx.inputUsername, ":")[1]
if len(inputToken) > 0 {
db.Where("invite_token = ?", inputToken).First(&actx.user)
}
if actx.user.ID > 0 {
actx.userKey = UserKey{
UserID: actx.user.ID,
Key: key.Marshal(),
Comment: "created by sshportal",
AuthorizedKey: string(gossh.MarshalAuthorizedKey(key)),
}
db.Create(actx.userKey)
// token is only usable once
actx.user.InviteToken = ""
db.Model(actx.user).Updates(actx.user)
actx.message = fmt.Sprintf("Welcome %s!\n\nYour key is now associated with the user %q.\n", actx.user.Name, actx.user.Email)
} else {
actx.user = User{Name: "Anonymous"}
actx.err = errors.New("your token is invalid or expired")
}
return true
}
// fallback
actx.err = errors.New("unknown ssh key")
actx.user = User{Name: "Anonymous"}
return true
}
}