mirror of
https://github.com/moul/sshportal.git
synced 2025-03-09 05:43:31 +08:00
Refactor sshportal: create a custom bastion session handler
This commit is contained in:
parent
4125bc2768
commit
072464928b
4 changed files with 447 additions and 71 deletions
177
main.go
177
main.go
|
@ -18,6 +18,8 @@ import (
|
|||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/urfave/cli"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/moul/sshportal/pkg/bastionsession"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -135,19 +137,19 @@ func server(c *cli.Context) error {
|
|||
if c.Bool("debug") {
|
||||
db.LogMode(true)
|
||||
}
|
||||
if err := dbInit(db); err != nil {
|
||||
if err = dbInit(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ssh server
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
shellHandler := func(s ssh.Session) {
|
||||
currentUser := s.Context().Value(userContextKey).(User)
|
||||
if s.User() != "healthcheck" {
|
||||
log.Printf("New connection: sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
|
||||
log.Printf("New connection(shell): sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
|
||||
}
|
||||
|
||||
if err := s.Context().Value(errorContextKey); err != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err)
|
||||
if err2 := s.Context().Value(errorContextKey); err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -160,74 +162,80 @@ func server(c *cli.Context) error {
|
|||
fmt.Fprintln(s, "OK")
|
||||
return
|
||||
case username == currentUser.Name || username == currentUser.Email || username == c.String("config-user"):
|
||||
if err := shell(c, s, s.Command(), db); err != nil {
|
||||
if err = shell(c, s, s.Command(), db); err != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err)
|
||||
}
|
||||
case strings.HasPrefix(username, "invite:"):
|
||||
return
|
||||
default:
|
||||
host, err := RemoteHostFromSession(s, db)
|
||||
if err != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err)
|
||||
// FIXME: print available hosts
|
||||
return
|
||||
}
|
||||
|
||||
// load up-to-date objects
|
||||
// FIXME: cache them or try not to load them
|
||||
var tmpUser User
|
||||
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", currentUser.ID).First(&tmpUser).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
var tmpHost Host
|
||||
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
|
||||
action, err2 := CheckACLs(tmpUser, tmpHost)
|
||||
if err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
|
||||
// decrypt key and password
|
||||
HostDecrypt(c.String("aes-key"), host)
|
||||
SSHKeyDecrypt(c.String("aes-key"), host.SSHKey)
|
||||
|
||||
switch action {
|
||||
case ACLActionAllow:
|
||||
sess := Session{
|
||||
UserID: currentUser.ID,
|
||||
HostID: host.ID,
|
||||
Status: SessionStatusActive,
|
||||
}
|
||||
if err2 := db.Create(&sess).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
sessUpdate := Session{}
|
||||
if err2 := proxy(s, host, DynamicHostKey(db, host)); err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
sessUpdate.ErrMsg = fmt.Sprintf("%v", err2)
|
||||
switch sessUpdate.ErrMsg {
|
||||
case "lch closed the connection", "rch closed the connection":
|
||||
sessUpdate.ErrMsg = ""
|
||||
}
|
||||
}
|
||||
sessUpdate.Status = SessionStatusClosed
|
||||
now := time.Now()
|
||||
sessUpdate.StoppedAt = &now
|
||||
db.Model(&sess).Updates(&sessUpdate)
|
||||
case ACLActionDeny:
|
||||
fmt.Fprintf(s, "You don't have permission to that host.\n")
|
||||
default:
|
||||
fmt.Fprintf(s, "error: invalid ACL action: %q\n", action)
|
||||
}
|
||||
|
||||
fmt.Fprintf(s, "error: invalid user\n")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
bastionHandler := func(s ssh.Session) {
|
||||
currentUser := s.Context().Value(userContextKey).(User)
|
||||
log.Printf("New connection(bastion): sshUser=%q remote=%q local=%q command=%q dbUser=id:%q,email:%s", s.User(), s.RemoteAddr(), s.LocalAddr(), s.Command(), currentUser.ID, currentUser.Email)
|
||||
var host *Host
|
||||
host, err = RemoteHostFromSession(s, db)
|
||||
if err != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err)
|
||||
// FIXME: print available hosts
|
||||
return
|
||||
}
|
||||
|
||||
// load up-to-date objects
|
||||
// FIXME: cache them or try not to load them
|
||||
var tmpUser User
|
||||
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", currentUser.ID).First(&tmpUser).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
var tmpHost Host
|
||||
if err2 := db.Preload("Groups").Preload("Groups.ACLs").Where("id = ?", host.ID).First(&tmpHost).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
|
||||
action, err2 := CheckACLs(tmpUser, tmpHost)
|
||||
if err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
|
||||
// decrypt key and password
|
||||
HostDecrypt(c.String("aes-key"), host)
|
||||
SSHKeyDecrypt(c.String("aes-key"), host.SSHKey)
|
||||
|
||||
switch action {
|
||||
case ACLActionAllow:
|
||||
sess := Session{
|
||||
UserID: currentUser.ID,
|
||||
HostID: host.ID,
|
||||
Status: SessionStatusActive,
|
||||
}
|
||||
if err2 := db.Create(&sess).Error; err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
return
|
||||
}
|
||||
sessUpdate := Session{}
|
||||
if err2 := proxy(s, host, DynamicHostKey(db, host)); err2 != nil {
|
||||
fmt.Fprintf(s, "error: %v\n", err2)
|
||||
sessUpdate.ErrMsg = fmt.Sprintf("%v", err2)
|
||||
switch sessUpdate.ErrMsg {
|
||||
case "lch closed the connection", "rch closed the connection":
|
||||
sessUpdate.ErrMsg = ""
|
||||
}
|
||||
}
|
||||
sessUpdate.Status = SessionStatusClosed
|
||||
now := time.Now()
|
||||
sessUpdate.StoppedAt = &now
|
||||
db.Model(&sess).Updates(&sessUpdate)
|
||||
case ACLActionDeny:
|
||||
fmt.Fprintf(s, "You don't have permission to that host.\n")
|
||||
default:
|
||||
fmt.Fprintf(s, "error: invalid ACL action: %q\n", action)
|
||||
}
|
||||
}
|
||||
|
||||
opts := []ssh.Option{}
|
||||
opts = append(opts, ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
|
||||
|
@ -289,12 +297,13 @@ func server(c *cli.Context) error {
|
|||
|
||||
opts = append(opts, func(srv *ssh.Server) error {
|
||||
var key SSHKey
|
||||
if err := SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
||||
if err = SSHKeysByIdentifiers(db, []string{"host"}).First(&key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
SSHKeyDecrypt(c.String("aes-key"), &key)
|
||||
|
||||
signer, err := gossh.ParsePrivateKey([]byte(key.PrivKey))
|
||||
var signer gossh.Signer
|
||||
signer, err = gossh.ParsePrivateKey([]byte(key.PrivKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -303,7 +312,35 @@ func server(c *cli.Context) error {
|
|||
})
|
||||
|
||||
log.Printf("info: SSH Server accepting connections on %s", c.String("bind-address"))
|
||||
return ssh.ListenAndServe(c.String("bind-address"), nil, opts...)
|
||||
ln, err := net.Listen("tcp", c.String("bind-address"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv := &ssh.Server{Addr: c.String("bind-address"), Handler: shellHandler}
|
||||
for _, opt := range opts {
|
||||
if err := srv.SetOption(opt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
srv.Version = fmt.Sprintf("sshportal-%s", Version)
|
||||
srv.ChannelHandler = func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if newChan.ChannelType() != "session" {
|
||||
if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
|
||||
log.Printf("error: failed to reject channel: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// TODO: handle direct-tcp
|
||||
|
||||
currentUser := ctx.Value(userContextKey).(User)
|
||||
username := conn.User()
|
||||
if username == c.String("healthcheck-user") || username == currentUser.Name || username == currentUser.Email || username == c.String("config-user") || strings.HasPrefix(username, "invite:") {
|
||||
ssh.DefaultChannelHandler(srv, conn, newChan, ctx)
|
||||
} else {
|
||||
bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionHandler)
|
||||
}
|
||||
}
|
||||
return srv.Serve(ln)
|
||||
}
|
||||
|
||||
// perform a healthcheck test without requiring an ssh client or an ssh key (used for Docker's HEALTHCHECK)
|
||||
|
|
256
pkg/bastionsession/bastionsession.go
Normal file
256
pkg/bastionsession/bastionsession.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package bastionsession
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/anmitsu/go-shlex"
|
||||
"github.com/gliderlabs/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// maxSigBufSize is how many signals will be buffered
|
||||
// when there is no signal channel specified
|
||||
const maxSigBufSize = 128
|
||||
|
||||
func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, handler ssh.Handler) {
|
||||
if newChan.ChannelType() != "session" {
|
||||
newChan.Reject(gossh.UnknownChannelType, "unsupported channel type")
|
||||
return
|
||||
}
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
// TODO: trigger event callback
|
||||
return
|
||||
}
|
||||
if handler == nil {
|
||||
handler = srv.Handler
|
||||
}
|
||||
sess := &session{
|
||||
Channel: ch,
|
||||
conn: conn,
|
||||
handler: handler,
|
||||
ptyCb: srv.PtyCallback,
|
||||
maskedReqs: make(chan *gossh.Request, 5),
|
||||
ctx: ctx,
|
||||
}
|
||||
sess.ctx.SetValue("masked-reqs", sess.maskedReqs)
|
||||
//ssh.DefaultChannelHandler(srv, conn, ch, ctx)
|
||||
//return
|
||||
sess.handleRequests(reqs)
|
||||
}
|
||||
|
||||
type session struct {
|
||||
sync.Mutex
|
||||
gossh.Channel
|
||||
conn *gossh.ServerConn
|
||||
handler ssh.Handler
|
||||
handled bool
|
||||
exited bool
|
||||
pty *ssh.Pty
|
||||
winch chan ssh.Window
|
||||
env []string
|
||||
ptyCb ssh.PtyCallback
|
||||
cmd []string
|
||||
ctx ssh.Context
|
||||
sigCh chan<- ssh.Signal
|
||||
sigBuf []ssh.Signal
|
||||
maskedReqs chan *gossh.Request
|
||||
}
|
||||
|
||||
func (sess *session) Write(p []byte) (n int, err error) {
|
||||
if sess.pty != nil {
|
||||
m := len(p)
|
||||
// normalize \n to \r\n when pty is accepted.
|
||||
// this is a hardcoded shortcut since we don't support terminal modes.
|
||||
p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1)
|
||||
p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1)
|
||||
n, err = sess.Channel.Write(p)
|
||||
if n > m {
|
||||
n = m
|
||||
}
|
||||
return
|
||||
}
|
||||
return sess.Channel.Write(p)
|
||||
}
|
||||
|
||||
func (sess *session) PublicKey() ssh.PublicKey {
|
||||
sessionkey := sess.ctx.Value(ssh.ContextKeyPublicKey)
|
||||
if sessionkey == nil {
|
||||
return nil
|
||||
}
|
||||
return sessionkey.(ssh.PublicKey)
|
||||
}
|
||||
|
||||
func (sess *session) Permissions() ssh.Permissions {
|
||||
// use context permissions because its properly
|
||||
// wrapped and easier to dereference
|
||||
perms := sess.ctx.Value(ssh.ContextKeyPermissions).(*ssh.Permissions)
|
||||
return *perms
|
||||
}
|
||||
|
||||
func (sess *session) Context() context.Context {
|
||||
return sess.ctx
|
||||
}
|
||||
|
||||
func (sess *session) Exit(code int) error {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
|
||||
if sess.exited {
|
||||
return errors.New("Session.Exit called multiple times")
|
||||
}
|
||||
sess.exited = true
|
||||
|
||||
status := struct{ Status uint32 }{uint32(code)}
|
||||
_, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
close(sess.maskedReqs)
|
||||
|
||||
return sess.Close()
|
||||
}
|
||||
|
||||
func (sess *session) User() string {
|
||||
return sess.conn.User()
|
||||
}
|
||||
|
||||
func (sess *session) RemoteAddr() net.Addr {
|
||||
return sess.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (sess *session) LocalAddr() net.Addr {
|
||||
return sess.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (sess *session) Environ() []string {
|
||||
return append([]string(nil), sess.env...)
|
||||
}
|
||||
|
||||
func (sess *session) Command() []string {
|
||||
return append([]string(nil), sess.cmd...)
|
||||
}
|
||||
|
||||
func (sess *session) Pty() (ssh.Pty, <-chan ssh.Window, bool) {
|
||||
if sess.pty != nil {
|
||||
return *sess.pty, sess.winch, true
|
||||
}
|
||||
return ssh.Pty{}, sess.winch, false
|
||||
}
|
||||
|
||||
func (sess *session) MaskedReqs() chan *gossh.Request {
|
||||
return sess.maskedReqs
|
||||
}
|
||||
|
||||
func (sess *session) Signals(c chan<- ssh.Signal) {
|
||||
sess.Lock()
|
||||
defer sess.Unlock()
|
||||
sess.sigCh = c
|
||||
if len(sess.sigBuf) > 0 {
|
||||
go func() {
|
||||
for _, sig := range sess.sigBuf {
|
||||
sess.sigCh <- sig
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
|
||||
for req := range reqs {
|
||||
addToMaskedReqs := true
|
||||
switch req.Type {
|
||||
case "shell", "exec":
|
||||
if sess.handled {
|
||||
req.Reply(false, nil)
|
||||
addToMaskedReqs = false
|
||||
continue
|
||||
}
|
||||
sess.handled = true
|
||||
// req.Reply(true, nil) // let the proxy reply
|
||||
|
||||
var payload = struct{ Value string }{}
|
||||
gossh.Unmarshal(req.Payload, &payload)
|
||||
sess.cmd, _ = shlex.Split(payload.Value, true)
|
||||
go func() {
|
||||
sess.handler(sess)
|
||||
sess.Exit(0)
|
||||
}()
|
||||
case "env":
|
||||
if sess.handled {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
var kv struct{ Key, Value string }
|
||||
gossh.Unmarshal(req.Payload, &kv)
|
||||
sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
|
||||
req.Reply(true, nil)
|
||||
case "signal":
|
||||
var payload struct{ Signal string }
|
||||
gossh.Unmarshal(req.Payload, &payload)
|
||||
sess.Lock()
|
||||
if sess.sigCh != nil {
|
||||
sess.sigCh <- ssh.Signal(payload.Signal)
|
||||
} else {
|
||||
if len(sess.sigBuf) < maxSigBufSize {
|
||||
sess.sigBuf = append(sess.sigBuf, ssh.Signal(payload.Signal))
|
||||
}
|
||||
}
|
||||
sess.Unlock()
|
||||
case "pty-req":
|
||||
if sess.handled || sess.pty != nil {
|
||||
req.Reply(false, nil)
|
||||
addToMaskedReqs = false
|
||||
continue
|
||||
}
|
||||
ptyReq, ok := parsePtyRequest(req.Payload)
|
||||
if !ok {
|
||||
req.Reply(false, nil)
|
||||
addToMaskedReqs = false
|
||||
continue
|
||||
}
|
||||
if sess.ptyCb != nil {
|
||||
ok := sess.ptyCb(sess.ctx, ptyReq)
|
||||
if !ok {
|
||||
req.Reply(false, nil)
|
||||
addToMaskedReqs = false
|
||||
continue
|
||||
}
|
||||
}
|
||||
sess.pty = &ptyReq
|
||||
sess.winch = make(chan ssh.Window, 1)
|
||||
sess.winch <- ptyReq.Window
|
||||
defer func() {
|
||||
// when reqs is closed
|
||||
close(sess.winch)
|
||||
}()
|
||||
//req.Reply(ok, nil) // let the proxy reply
|
||||
case "window-change":
|
||||
if sess.pty == nil {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
win, ok := parseWinchRequest(req.Payload)
|
||||
if ok {
|
||||
sess.pty.Window = win
|
||||
sess.winch <- win
|
||||
}
|
||||
req.Reply(ok, nil)
|
||||
case "auth-agent-req@openssh.com":
|
||||
// TODO: option/callback to allow agent forwarding
|
||||
ssh.SetAgentRequested(sess.ctx)
|
||||
req.Reply(true, nil)
|
||||
default:
|
||||
// TODO: debug log
|
||||
}
|
||||
|
||||
if addToMaskedReqs {
|
||||
sess.maskedReqs <- req
|
||||
}
|
||||
}
|
||||
}
|
79
pkg/bastionsession/util.go
Normal file
79
pkg/bastionsession/util.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package bastionsession
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
func parsePtyRequest(s []byte) (pty ssh.Pty, ok bool) {
|
||||
term, s, ok := parseString(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
width32, s, ok := parseUint32(s)
|
||||
if width32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
if height32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
pty = ssh.Pty{
|
||||
Term: term,
|
||||
Window: ssh.Window{
|
||||
Width: int(width32),
|
||||
Height: int(height32),
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseWinchRequest(s []byte) (win ssh.Window, ok bool) {
|
||||
width32, s, ok := parseUint32(s)
|
||||
if width32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
if height32 < 1 {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
win = ssh.Window{
|
||||
Width: int(width32),
|
||||
Height: int(height32),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseString(in []byte) (out string, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
length := binary.BigEndian.Uint32(in)
|
||||
if uint32(len(in)) < 4+length {
|
||||
return
|
||||
}
|
||||
out = string(in[4 : 4+length])
|
||||
rest = in[4+length:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseUint32(in []byte) (uint32, []byte, bool) {
|
||||
if len(in) < 4 {
|
||||
return 0, nil, false
|
||||
}
|
||||
return binary.BigEndian.Uint32(in), in[4:], true
|
||||
}
|
6
proxy.go
6
proxy.go
|
@ -28,7 +28,11 @@ func proxy(s ssh.Session, host *Host, hk gossh.HostKeyCallback) error {
|
|||
}
|
||||
|
||||
log.Println("SSH Connection established")
|
||||
return pipe(s.MaskedReqs(), rreqs, s, rch)
|
||||
maskedReqs := s.Context().Value("masked-reqs")
|
||||
if maskedReqs == nil {
|
||||
return fmt.Errorf("ctx.maskedReqs doesn't exist")
|
||||
}
|
||||
return pipe(maskedReqs.(chan *gossh.Request), rreqs, s, rch)
|
||||
}
|
||||
|
||||
func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel) error {
|
||||
|
|
Loading…
Reference in a new issue