sshportal/pkg/bastionsession/bastionsession.go
2018-04-02 22:36:06 +02:00

203 lines
4.8 KiB
Go

package bastionsession
import (
"errors"
"io"
"log"
"os"
"strings"
"time"
"github.com/arkan/bastion/pkg/logchannel"
"github.com/moul/sshportal/pkg/logtunnel"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
type ForwardData struct {
DestinationHost string
DestinationPort uint32
SourceHost string
SourcePort uint32
}
type Config struct {
Addr string
Logs string
ClientConfig *gossh.ClientConfig
}
func MultiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []Config) error {
var lastClient *gossh.Client
switch newChan.ChannelType() {
case "session" :
lch, lreqs, err := newChan.Accept()
// TODO: defer clean closer
if err != nil {
// TODO: trigger event callback
return nil
}
// go through all the hops
for _, config := range configs {
var client *gossh.Client
if lastClient == nil {
client, err = gossh.Dial("tcp", config.Addr, config.ClientConfig)
} else {
rconn, err := lastClient.Dial("tcp", config.Addr)
if err != nil {
return err
}
ncc, chans, reqs, err := gossh.NewClientConn(rconn, config.Addr, config.ClientConfig)
if err != nil {
return err
}
client = gossh.NewClient(ncc, chans, reqs)
}
if err != nil {
return err
}
defer func() { _ = client.Close() }()
lastClient = client
}
rch, rreqs, err := lastClient.OpenChannel("session", []byte{})
if err != nil {
return err
}
user := conn.User()
// pipe everything
return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, newChan)
case "direct-tcpip":
lch, lreqs, err := newChan.Accept()
// TODO: defer clean closer
if err != nil {
// TODO: trigger event callback
return nil
}
// go through all the hops
for _, config := range configs {
var client *gossh.Client
if lastClient == nil {
client, err = gossh.Dial("tcp", config.Addr, config.ClientConfig)
} else {
rconn, err := lastClient.Dial("tcp", config.Addr)
if err != nil {
return err
}
ncc, chans, reqs, err := gossh.NewClientConn(rconn, config.Addr, config.ClientConfig)
if err != nil {
return err
}
client = gossh.NewClient(ncc, chans, reqs)
}
if err != nil {
return err
}
defer func() { _ = client.Close() }()
lastClient = client
}
d := logtunnel.ForwardData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
return err
}
rch, rreqs, err := lastClient.OpenChannel("direct-tcpip", newChan.ExtraData())
if err != nil {
return err
}
user := conn.User()
// pipe everything
return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, newChan)
default:
newChan.Reject(gossh.UnknownChannelType, "unsupported channel type")
return nil
}
}
func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocation string, user string, newChan gossh.NewChannel) error {
defer func() {
_ = lch.Close()
_ = rch.Close()
}()
errch := make(chan error, 1)
channeltype := newChan.ChannelType()
file_name := strings.Join([]string{logsLocation, "/", user, "-", channeltype, "-", time.Now().Format(time.RFC3339)}, "") // get user
f, err := os.OpenFile(file_name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0640)
defer f.Close()
if err != nil {
log.Fatalf("error: %v", err)
}
log.Printf("Session %v is recorded in %v", channeltype, file_name)
if channeltype == "session" {
wrappedlch := logchannel.New(lch, f)
go func() {
_, _ = io.Copy(wrappedlch, rch)
errch <- errors.New("lch closed the connection")
}()
go func() {
_, _ = io.Copy(rch, lch)
errch <- errors.New("rch closed the connection")
}()
}
if channeltype == "direct-tcpip" {
d := logtunnel.ForwardData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
return err
}
wrappedlch := logtunnel.New(lch, f, d.SourceHost)
wrappedrch := logtunnel.New(rch, f, d.DestinationHost)
go func() {
_, _ = io.Copy(wrappedlch, rch)
errch <- errors.New("lch closed the connection")
}()
go func() {
_, _ = io.Copy(wrappedrch, lch)
errch <- errors.New("rch closed the connection")
}()
}
for {
select {
case req := <-lreqs: // forward ssh requests from local to remote
if req == nil {
return nil
}
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
if req.Type == "exec" {
command := append(req.Payload, []byte("\n")...)
wrappedlch.LogWrite(command)
}
if err != nil {
return err
}
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
case req := <-rreqs: // forward ssh requests from remote to local
if req == nil {
return nil
}
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return err
}
if err2 := req.Reply(b, nil); err2 != nil {
return err2
}
case err := <-errch:
return err
}
}
}