Merge pull request #183 from jrrdev/exit_fix

This commit is contained in:
Manfred Touron 2020-07-01 14:16:14 +02:00 committed by GitHub
commit 7f3ea431a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 29 deletions

View file

@ -1,7 +1,6 @@
package bastion // import "moul.io/sshportal/pkg/bastion" package bastion // import "moul.io/sshportal/pkg/bastion"
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -124,6 +123,7 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
}() }()
errch := make(chan error, 1) errch := make(chan error, 1)
quit := make(chan string, 1)
channeltype := newChan.ChannelType() channeltype := newChan.ChannelType()
filename := strings.Join([]string{logsLocation, "/", user, "-", username, "-", channeltype, "-", fmt.Sprint(sessionID), "-", time.Now().Format(time.RFC3339)}, "") // get user filename := strings.Join([]string{logsLocation, "/", user, "-", username, "-", channeltype, "-", fmt.Sprint(sessionID), "-", time.Now().Format(time.RFC3339)}, "") // get user
@ -139,15 +139,15 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
log.Printf("Session %v is recorded in %v", channeltype, filename) log.Printf("Session %v is recorded in %v", channeltype, filename)
if channeltype == "session" { if channeltype == "session" {
wrappedlch := logchannel.New(lch, f) wrappedlch := logchannel.New(lch, f)
go func() { go func(quit chan string) {
_, _ = io.Copy(wrappedlch, rch) _, _ = io.Copy(wrappedlch, rch)
errch <- errors.New("lch closed the connection") quit <- "rch"
}() }(quit)
go func() { go func(quit chan string) {
_, _ = io.Copy(rch, lch) _, _ = io.Copy(rch, lch)
errch <- errors.New("rch closed the connection") quit <- "lch"
}() }(quit)
} }
if channeltype == "direct-tcpip" { if channeltype == "direct-tcpip" {
d := logTunnelForwardData{} d := logTunnelForwardData{}
@ -156,23 +156,19 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
} }
wrappedlch := newLogTunnel(lch, f, d.SourceHost) wrappedlch := newLogTunnel(lch, f, d.SourceHost)
wrappedrch := newLogTunnel(rch, f, d.DestinationHost) wrappedrch := newLogTunnel(rch, f, d.DestinationHost)
go func() { go func(quit chan string) {
_, _ = io.Copy(wrappedlch, rch) _, _ = io.Copy(wrappedlch, rch)
errch <- errors.New("lch closed the connection") quit <- "rch"
}() }(quit)
go func() { go func(quit chan string) {
_, _ = io.Copy(wrappedrch, lch) _, _ = io.Copy(wrappedrch, lch)
errch <- errors.New("rch closed the connection") quit <- "lch"
}() }(quit)
} }
for { go func(quit chan string) {
select { for req := range lreqs {
case req := <-lreqs: // forward ssh requests from local to remote
if req == nil {
return nil
}
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload) b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
if req.Type == "exec" { if req.Type == "exec" {
wrappedlch := logchannel.New(lch, f) wrappedlch := logchannel.New(lch, f)
@ -183,24 +179,58 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
} }
if err != nil { if err != nil {
return err errch <- err
} }
if err2 := req.Reply(b, nil); err2 != nil { if err2 := req.Reply(b, nil); err2 != nil {
return err2 errch <- err2
}
case req := <-rreqs: // forward ssh requests from remote to local
if req == nil {
return nil
} }
}
quit <- "lreqs"
}(quit)
go func(quit chan string) {
for req := range rreqs {
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload) b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil { if err != nil {
return err errch <- err
} }
if err2 := req.Reply(b, nil); err2 != nil { if err2 := req.Reply(b, nil); err2 != nil {
return err2 errch <- err2
} }
}
quit <- "rreqs"
}(quit)
lchEOF, rchEOF, lchClosed, rchClosed := false, false, false, false
for {
select {
case err := <-errch: case err := <-errch:
return err return err
case q := <-quit:
switch q {
case "lch":
lchEOF = true
_ = rch.CloseWrite()
case "rch":
rchEOF = true
_ = lch.CloseWrite()
case "lreqs":
lchClosed = true
case "rreqs":
rchClosed = true
}
if lchEOF && lchClosed && !rchClosed {
rch.Close()
}
if rchEOF && rchClosed && !lchClosed {
lch.Close()
}
if lchEOF && rchEOF && lchClosed && rchClosed {
return nil
}
} }
} }
} }

View file

@ -160,8 +160,7 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
ErrMsg: fmt.Sprintf("%v", err), ErrMsg: fmt.Sprintf("%v", err),
StoppedAt: &now, StoppedAt: &now,
} }
switch sessUpdate.ErrMsg { if err == nil {
case "lch closed the connection", "rch closed the connection":
sessUpdate.ErrMsg = "" sessUpdate.ErrMsg = ""
} }
actx.db.Model(&sess).Updates(&sessUpdate) actx.db.Model(&sess).Updates(&sessUpdate)