diff --git a/db.go b/db.go index ff2195f..071587c 100644 --- a/db.go +++ b/db.go @@ -63,6 +63,8 @@ type Host struct { HostKey []byte `sql:"size:10000" valid:"optional"` Groups []*HostGroup `gorm:"many2many:host_host_groups;"` Comment string `valid:"optional"` + Hop *Host + HopID uint } // UserKey defines a user public key used by sshportal to identify the user diff --git a/dbinit.go b/dbinit.go index 4c42b19..bdfe7f8 100644 --- a/dbinit.go +++ b/dbinit.go @@ -458,6 +458,30 @@ func dbInit(db *gorm.DB) error { Rollback: func(tx *gorm.DB) error { return fmt.Errorf("not implemented") }, + }, { + ID: "29", + Migrate: func(tx *gorm.DB) error { + type Host struct { + // FIXME: use uuid for ID + gorm.Model + Name string `gorm:"size:32"` + Addr string + User string + Password string + URL string + SSHKey *SSHKey `gorm:"ForeignKey:SSHKeyID"` + SSHKeyID uint `gorm:"index"` + HostKey []byte `sql:"size:10000"` + Groups []*HostGroup `gorm:"many2many:host_host_groups;"` + Comment string + Hop *Host + HopID uint + } + return tx.AutoMigrate(&Host{}).Error + }, + Rollback: func(tx *gorm.DB) error { + return fmt.Errorf("not implemented") + }, }, }) if err := m.Migrate(); err != nil { diff --git a/pkg/bastionsession/bastionsession.go b/pkg/bastionsession/bastionsession.go index 5d1ebfc..76465a0 100644 --- a/pkg/bastionsession/bastionsession.go +++ b/pkg/bastionsession/bastionsession.go @@ -3,13 +3,13 @@ package bastionsession import ( "errors" "io" + "log" + "os" "strings" "time" - "os" - "log" - - "github.com/gliderlabs/ssh" + "github.com/arkan/bastion/pkg/logchannel" + "github.com/gliderlabs/ssh" gossh "golang.org/x/crypto/ssh" ) @@ -19,7 +19,7 @@ type Config struct { ClientConfig *gossh.ClientConfig } -func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, config Config) error { +func MultiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []Config) error { if newChan.ChannelType() != "session" { newChan.Reject(gossh.UnknownChannelType, "unsupported channel type") return nil @@ -31,19 +31,37 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh return nil } - // open client channel - rconn, err := gossh.Dial("tcp", config.Addr, config.ClientConfig) - if err != nil { - return err + var lastClient *gossh.Client + + // go through all the hops + for _, config := range configs { + var client *gossh.Client + if lastClient == nil { + client, err = gossh.Dial("tcp", config.Addr, config.ClientConfig) + } else { + rconn, err := lastClient.Dial("tcp", config.Addr) + if err != nil { + return err + } + ncc, chans, reqs, err := gossh.NewClientConn(rconn, config.Addr, config.ClientConfig) + if err != nil { + return err + } + client = gossh.NewClient(ncc, chans, reqs) + } + if err != nil { + return err + } + defer func() { _ = client.Close() }() + lastClient = client } - defer func() { _ = rconn.Close() }() - rch, rreqs, err := rconn.OpenChannel("session", []byte{}) + rch, rreqs, err := lastClient.OpenChannel("session", []byte{}) if err != nil { return err } user := conn.User() // pipe everything - return pipe(lreqs, rreqs, lch, rch, config.Logs, user) + return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user) } func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocation string, user string) error { @@ -57,7 +75,7 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati f, err := os.OpenFile(file_name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0640) if err != nil { log.Fatalf("error: %v", err) - } + } log.Printf("Session is recorded in %v", file_name) wrappedlch := logchannel.New(lch, f) @@ -65,9 +83,9 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati _, _ = io.Copy(wrappedlch, rch) errch <- errors.New("lch closed the connection") }() - + defer f.Close() - + go func() { _, _ = io.Copy(rch, lch) errch <- errors.New("rch closed the connection") diff --git a/shell.go b/shell.go index 178d67c..ac69b53 100644 --- a/shell.go +++ b/shell.go @@ -642,6 +642,7 @@ GLOBAL OPTIONS: cli.StringFlag{Name: "password, p", Usage: "If present, sshportal will use password-based authentication"}, cli.StringFlag{Name: "comment, c"}, cli.StringFlag{Name: "key, k", Usage: "`KEY` to use for authentication"}, + cli.StringFlag{Name: "hop, o", Usage: "Hop to use for connecting to the server"}, cli.StringSliceFlag{Name: "group, g", Usage: "Assigns the host to `HOSTGROUPS` (default: \"default\")"}, }, Action: func(c *cli.Context) error { @@ -665,7 +666,13 @@ GLOBAL OPTIONS: host.Password = c.String("password") } host.Name = strings.Split(host.Hostname(), ".")[0] - + if c.String("hop") != "" { + hop, err := HostByName(db, c.String("hop")) + if err != nil { + return err + } + host.Hop = hop + } if c.String("name") != "" { host.Name = c.String("name") } @@ -776,7 +783,7 @@ GLOBAL OPTIONS: } table := tablewriter.NewWriter(s) - table.SetHeader([]string{"ID", "Name", "URL", "Key", "Groups", "Updated", "Created", "Comment"}) + table.SetHeader([]string{"ID", "Name", "URL", "Key", "Groups", "Updated", "Created", "Comment", "Hop"}) table.SetBorder(false) table.SetCaption(true, fmt.Sprintf("Total: %d hosts.", len(hosts))) for _, host := range hosts { @@ -790,6 +797,14 @@ GLOBAL OPTIONS: for _, hostGroup := range host.Groups { groupNames = append(groupNames, hostGroup.Name) } + var hop string + if host.HopID != 0 { + var hopHost Host + db.Model(&host).Related(&hopHost, "HopID") + hop = hopHost.Name + } else { + hop = "" + } table.Append([]string{ fmt.Sprintf("%d", host.ID), host.Name, @@ -799,6 +814,7 @@ GLOBAL OPTIONS: humanize.Time(host.UpdatedAt), humanize.Time(host.CreatedAt), host.Comment, + hop, //FIXME: add some stats about last access time etc }) } diff --git a/ssh.go b/ssh.go index 52c360f..9e632dc 100644 --- a/ssh.go +++ b/ssh.go @@ -113,16 +113,33 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh switch host.Scheme() { case BastionSchemeSSH: - clientConfig, err := bastionClientConfig(ctx, host) - if err != nil { - ch, _, err2 := newChan.Accept() + sessionConfigs := make([]bastionsession.Config, 0) + currentHost := host + for currentHost != nil { + clientConfig, err2 := bastionClientConfig(ctx, currentHost) if err2 != nil { + ch, _, err3 := newChan.Accept() + if err3 != nil { + return + } + fmt.Fprintf(ch, "error: %v\n", err2) + // FIXME: force close all channels + _ = ch.Close() return } - fmt.Fprintf(ch, "error: %v\n", err) - // FIXME: force close all channels - _ = ch.Close() - return + sessionConfigs = append([]bastionsession.Config{{ + Addr: currentHost.DialAddr(), + ClientConfig: clientConfig, + Logs: actx.config.logsLocation, + }}, sessionConfigs...) + if currentHost.HopID != 0 { + var newHost Host + actx.db.Model(currentHost).Related(&newHost, "HopID") + hostname := newHost.Name + currentHost, _ = HostByName(actx.db, hostname) + } else { + currentHost = nil + } } sess := Session{ @@ -140,11 +157,7 @@ func channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh return } - err = bastionsession.ChannelHandler(srv, conn, newChan, ctx, bastionsession.Config{ - Addr: host.DialAddr(), - ClientConfig: clientConfig, - Logs: actx.config.logsLocation, - }) + err = bastionsession.MultiChannelHandler(srv, conn, newChan, ctx, sessionConfigs) now := time.Now() sessUpdate := Session{