Implement proxied connections

The feature is implemented as follows:
- when creating a host, there is a possiblity to add a "hop"
- hops are referend them with the name of the host in sshportal
- the hop ID is then saved in the DB in the hosts table
- when connecting to a host, sshportal will recurse through all the
  possible hops of a host (allowing chained proxies)
This commit is contained in:
Mathieu Pasquet 2018-02-22 17:50:55 +01:00
parent e6a02a85f0
commit 75c6840ecd
5 changed files with 102 additions and 29 deletions

2
db.go
View file

@ -63,6 +63,8 @@ type Host struct {
HostKey []byte `sql:"size:10000" valid:"optional"`
Groups []*HostGroup `gorm:"many2many:host_host_groups;"`
Comment string `valid:"optional"`
Hop *Host
HopID uint
}
// UserKey defines a user public key used by sshportal to identify the user

View file

@ -458,6 +458,30 @@ func dbInit(db *gorm.DB) error {
Rollback: func(tx *gorm.DB) error {
return fmt.Errorf("not implemented")
},
}, {
ID: "29",
Migrate: func(tx *gorm.DB) error {
type Host struct {
// FIXME: use uuid for ID
gorm.Model
Name string `gorm:"size:32"`
Addr string
User string
Password string
URL string
SSHKey *SSHKey `gorm:"ForeignKey:SSHKeyID"`
SSHKeyID uint `gorm:"index"`
HostKey []byte `sql:"size:10000"`
Groups []*HostGroup `gorm:"many2many:host_host_groups;"`
Comment string
Hop *Host
HopID uint
}
return tx.AutoMigrate(&Host{}).Error
},
Rollback: func(tx *gorm.DB) error {
return fmt.Errorf("not implemented")
},
},
})
if err := m.Migrate(); err != nil {

View file

@ -3,13 +3,13 @@ package bastionsession
import (
"errors"
"io"
"log"
"os"
"strings"
"time"
"os"
"log"
"github.com/gliderlabs/ssh"
"github.com/arkan/bastion/pkg/logchannel"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
@ -19,7 +19,7 @@ type Config struct {
ClientConfig *gossh.ClientConfig
}
func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, config Config) error {
func MultiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []Config) error {
if newChan.ChannelType() != "session" {
newChan.Reject(gossh.UnknownChannelType, "unsupported channel type")
return nil
@ -31,19 +31,37 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
return nil
}
// open client channel
rconn, err := gossh.Dial("tcp", config.Addr, config.ClientConfig)
if err != nil {
return err
var lastClient *gossh.Client
// go through all the hops
for _, config := range configs {
var client *gossh.Client
if lastClient == nil {
client, err = gossh.Dial("tcp", config.Addr, config.ClientConfig)
} else {
rconn, err := lastClient.Dial("tcp", config.Addr)
if err != nil {
return err
}
ncc, chans, reqs, err := gossh.NewClientConn(rconn, config.Addr, config.ClientConfig)
if err != nil {
return err
}
client = gossh.NewClient(ncc, chans, reqs)
}
if err != nil {
return err
}
defer func() { _ = client.Close() }()
lastClient = client
}
defer func() { _ = rconn.Close() }()
rch, rreqs, err := rconn.OpenChannel("session", []byte{})
rch, rreqs, err := lastClient.OpenChannel("session", []byte{})
if err != nil {
return err
}
user := conn.User()
// pipe everything
return pipe(lreqs, rreqs, lch, rch, config.Logs, user)
return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user)
}
func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocation string, user string) error {
@ -57,7 +75,7 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
f, err := os.OpenFile(file_name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0640)
if err != nil {
log.Fatalf("error: %v", err)
}
}
log.Printf("Session is recorded in %v", file_name)
wrappedlch := logchannel.New(lch, f)
@ -65,9 +83,9 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
_, _ = io.Copy(wrappedlch, rch)
errch <- errors.New("lch closed the connection")
}()
defer f.Close()
go func() {
_, _ = io.Copy(rch, lch)
errch <- errors.New("rch closed the connection")

View file

@ -642,6 +642,7 @@ GLOBAL OPTIONS:
cli.StringFlag{Name: "password, p", Usage: "If present, sshportal will use password-based authentication"},
cli.StringFlag{Name: "comment, c"},
cli.StringFlag{Name: "key, k", Usage: "`KEY` to use for authentication"},
cli.StringFlag{Name: "hop, o", Usage: "Hop to use for connecting to the server"},
cli.StringSliceFlag{Name: "group, g", Usage: "Assigns the host to `HOSTGROUPS` (default: \"default\")"},
},
Action: func(c *cli.Context) error {
@ -665,7 +666,13 @@ GLOBAL OPTIONS:
host.Password = c.String("password")
}
host.Name = strings.Split(host.Hostname(), ".")[0]
if c.String("hop") != "" {
hop, err := HostByName(db, c.String("hop"))
if err != nil {
return err
}
host.Hop = hop
}
if c.String("name") != "" {
host.Name = c.String("name")
}
@ -776,7 +783,7 @@ GLOBAL OPTIONS:
}
table := tablewriter.NewWriter(s)
table.SetHeader([]string{"ID", "Name", "URL", "Key", "Groups", "Updated", "Created", "Comment"})
table.SetHeader([]string{"ID", "Name", "URL", "Key", "Groups", "Updated", "Created", "Comment", "Hop"})
table.SetBorder(false)
table.SetCaption(true, fmt.Sprintf("Total: %d hosts.", len(hosts)))
for _, host := range hosts {
@ -790,6 +797,14 @@ GLOBAL OPTIONS:
for _, hostGroup := range host.Groups {
groupNames = append(groupNames, hostGroup.Name)
}
var hop string
if host.HopID != 0 {
var hopHost Host
db.Model(&host).Related(&hopHost, "HopID")
hop = hopHost.Name
} else {
hop = ""
}
table.Append([]string{
fmt.Sprintf("%d", host.ID),
host.Name,
@ -799,6 +814,7 @@ GLOBAL OPTIONS:
humanize.Time(host.UpdatedAt),
humanize.Time(host.CreatedAt),
host.Comment,
hop,
//FIXME: add some stats about last access time etc
})
}

37
ssh.go
View file

@ -113,16 +113,33 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
switch host.Scheme() {
case BastionSchemeSSH:
clientConfig, err := bastionClientConfig(ctx, host)
if err != nil {
ch, _, err2 := newChan.Accept()
sessionConfigs := make([]bastionsession.Config, 0)
currentHost := host
for currentHost != nil {
clientConfig, err2 := bastionClientConfig(ctx, currentHost)
if err2 != nil {
ch, _, err3 := newChan.Accept()
if err3 != nil {
return
}
fmt.Fprintf(ch, "error: %v\n", err2)
// FIXME: force close all channels
_ = ch.Close()
return
}
fmt.Fprintf(ch, "error: %v\n", err)
// FIXME: force close all channels
_ = ch.Close()
return
sessionConfigs = append([]bastionsession.Config{{
Addr: currentHost.DialAddr(),
ClientConfig: clientConfig,
Logs: actx.config.logsLocation,
}}, sessionConfigs...)
if currentHost.HopID != 0 {
var newHost Host
actx.db.Model(currentHost).Related(&newHost, "HopID")
hostname := newHost.Name
currentHost, _ = HostByName(actx.db, hostname)
} else {
currentHost = nil
}
}
sess := Session{
@ -140,11 +157,7 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
return
}
err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{
Addr: host.DialAddr(),
ClientConfig: clientConfig,
Logs: actx.config.logsLocation,
})
err = bastionsession.MultiChannelHandler(srv, conn, newChan, ctx, sessionConfigs)
now := time.Now()
sessUpdate := Session{