diff --git a/pkg/bastion/session.go b/pkg/bastion/session.go index 3ec2d81..9292f9e 100644 --- a/pkg/bastion/session.go +++ b/pkg/bastion/session.go @@ -2,6 +2,7 @@ package bastion // import "moul.io/sshportal/pkg/bastion" import ( "errors" + "fmt" "io" "log" "os" @@ -19,7 +20,7 @@ type sessionConfig struct { ClientConfig *gossh.ClientConfig } -func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []sessionConfig) error { +func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context, configs []sessionConfig, sessionID uint) error { var lastClient *gossh.Client switch newChan.ChannelType() { case "session": @@ -59,8 +60,10 @@ func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh. return err } user := conn.User() + actx := ctx.Value(authContextKey).(*authContext) + username := actx.user.Name // pipe everything - return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, newChan) + return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, username, sessionID, newChan) case "direct-tcpip": lch, lreqs, err := newChan.Accept() // TODO: defer clean closer @@ -102,8 +105,10 @@ func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh. return err } user := conn.User() + actx := ctx.Value(authContextKey).(*authContext) + username := actx.user.Name // pipe everything - return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, newChan) + return pipe(lreqs, rreqs, lch, rch, configs[len(configs)-1].Logs, user, username, sessionID, newChan) default: if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil { log.Printf("failed to reject chan: %v", err) @@ -112,7 +117,7 @@ func multiChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh. } } -func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocation string, user string, newChan gossh.NewChannel) error { +func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocation string, user string, username string, sessionID uint, newChan gossh.NewChannel) error { defer func() { _ = lch.Close() _ = rch.Close() @@ -121,7 +126,7 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati errch := make(chan error, 1) channeltype := newChan.ChannelType() - filename := strings.Join([]string{logsLocation, "/", user, "-", channeltype, "-", time.Now().Format(time.RFC3339)}, "") // get user + filename := strings.Join([]string{logsLocation, "/", user, "-", username, "-", channeltype, "-", fmt.Sprint(sessionID), "-", time.Now().Format(time.RFC3339)}, "") // get user f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0640) defer func() { _ = f.Close() diff --git a/pkg/bastion/ssh.go b/pkg/bastion/ssh.go index 8641c01..07ba117 100644 --- a/pkg/bastion/ssh.go +++ b/pkg/bastion/ssh.go @@ -148,9 +148,8 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh _ = ch.Close() return } - go func() { - err = multiChannelHandler(srv, conn, newChan, ctx, sessionConfigs) + err = multiChannelHandler(srv, conn, newChan, ctx, sessionConfigs, sess.ID) if err != nil { log.Printf("Error: %v", err) }