mirror of
https://github.com/moul/sshportal.git
synced 2024-09-20 15:06:07 +08:00
Merge pull request #183 from jrrdev/exit_fix
This commit is contained in:
commit
7f3ea431a1
|
@ -1,7 +1,6 @@
|
||||||
package bastion // import "moul.io/sshportal/pkg/bastion"
|
package bastion // import "moul.io/sshportal/pkg/bastion"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
@ -124,6 +123,7 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errch := make(chan error, 1)
|
errch := make(chan error, 1)
|
||||||
|
quit := make(chan string, 1)
|
||||||
channeltype := newChan.ChannelType()
|
channeltype := newChan.ChannelType()
|
||||||
|
|
||||||
filename := strings.Join([]string{logsLocation, "/", user, "-", username, "-", channeltype, "-", fmt.Sprint(sessionID), "-", time.Now().Format(time.RFC3339)}, "") // get user
|
filename := strings.Join([]string{logsLocation, "/", user, "-", username, "-", channeltype, "-", fmt.Sprint(sessionID), "-", time.Now().Format(time.RFC3339)}, "") // get user
|
||||||
|
@ -139,15 +139,15 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
|
||||||
log.Printf("Session %v is recorded in %v", channeltype, filename)
|
log.Printf("Session %v is recorded in %v", channeltype, filename)
|
||||||
if channeltype == "session" {
|
if channeltype == "session" {
|
||||||
wrappedlch := logchannel.New(lch, f)
|
wrappedlch := logchannel.New(lch, f)
|
||||||
go func() {
|
go func(quit chan string) {
|
||||||
_, _ = io.Copy(wrappedlch, rch)
|
_, _ = io.Copy(wrappedlch, rch)
|
||||||
errch <- errors.New("lch closed the connection")
|
quit <- "rch"
|
||||||
}()
|
}(quit)
|
||||||
|
|
||||||
go func() {
|
go func(quit chan string) {
|
||||||
_, _ = io.Copy(rch, lch)
|
_, _ = io.Copy(rch, lch)
|
||||||
errch <- errors.New("rch closed the connection")
|
quit <- "lch"
|
||||||
}()
|
}(quit)
|
||||||
}
|
}
|
||||||
if channeltype == "direct-tcpip" {
|
if channeltype == "direct-tcpip" {
|
||||||
d := logTunnelForwardData{}
|
d := logTunnelForwardData{}
|
||||||
|
@ -156,23 +156,19 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
|
||||||
}
|
}
|
||||||
wrappedlch := newLogTunnel(lch, f, d.SourceHost)
|
wrappedlch := newLogTunnel(lch, f, d.SourceHost)
|
||||||
wrappedrch := newLogTunnel(rch, f, d.DestinationHost)
|
wrappedrch := newLogTunnel(rch, f, d.DestinationHost)
|
||||||
go func() {
|
go func(quit chan string) {
|
||||||
_, _ = io.Copy(wrappedlch, rch)
|
_, _ = io.Copy(wrappedlch, rch)
|
||||||
errch <- errors.New("lch closed the connection")
|
quit <- "rch"
|
||||||
}()
|
}(quit)
|
||||||
|
|
||||||
go func() {
|
go func(quit chan string) {
|
||||||
_, _ = io.Copy(wrappedrch, lch)
|
_, _ = io.Copy(wrappedrch, lch)
|
||||||
errch <- errors.New("rch closed the connection")
|
quit <- "lch"
|
||||||
}()
|
}(quit)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
go func(quit chan string) {
|
||||||
select {
|
for req := range lreqs {
|
||||||
case req := <-lreqs: // forward ssh requests from local to remote
|
|
||||||
if req == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
|
b, err := rch.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||||
if req.Type == "exec" {
|
if req.Type == "exec" {
|
||||||
wrappedlch := logchannel.New(lch, f)
|
wrappedlch := logchannel.New(lch, f)
|
||||||
|
@ -183,24 +179,58 @@ func pipe(lreqs, rreqs <-chan *gossh.Request, lch, rch gossh.Channel, logsLocati
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
errch <- err
|
||||||
}
|
}
|
||||||
if err2 := req.Reply(b, nil); err2 != nil {
|
if err2 := req.Reply(b, nil); err2 != nil {
|
||||||
return err2
|
errch <- err2
|
||||||
}
|
|
||||||
case req := <-rreqs: // forward ssh requests from remote to local
|
|
||||||
if req == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
quit <- "lreqs"
|
||||||
|
}(quit)
|
||||||
|
|
||||||
|
go func(quit chan string) {
|
||||||
|
for req := range rreqs {
|
||||||
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
|
b, err := lch.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
errch <- err
|
||||||
}
|
}
|
||||||
if err2 := req.Reply(b, nil); err2 != nil {
|
if err2 := req.Reply(b, nil); err2 != nil {
|
||||||
return err2
|
errch <- err2
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
quit <- "rreqs"
|
||||||
|
}(quit)
|
||||||
|
|
||||||
|
lchEOF, rchEOF, lchClosed, rchClosed := false, false, false, false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
case err := <-errch:
|
case err := <-errch:
|
||||||
return err
|
return err
|
||||||
|
case q := <-quit:
|
||||||
|
switch q {
|
||||||
|
case "lch":
|
||||||
|
lchEOF = true
|
||||||
|
_ = rch.CloseWrite()
|
||||||
|
case "rch":
|
||||||
|
rchEOF = true
|
||||||
|
_ = lch.CloseWrite()
|
||||||
|
case "lreqs":
|
||||||
|
lchClosed = true
|
||||||
|
case "rreqs":
|
||||||
|
rchClosed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if lchEOF && lchClosed && !rchClosed {
|
||||||
|
rch.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if rchEOF && rchClosed && !lchClosed {
|
||||||
|
lch.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if lchEOF && rchEOF && lchClosed && rchClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -160,8 +160,7 @@ func ChannelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewCh
|
||||||
ErrMsg: fmt.Sprintf("%v", err),
|
ErrMsg: fmt.Sprintf("%v", err),
|
||||||
StoppedAt: &now,
|
StoppedAt: &now,
|
||||||
}
|
}
|
||||||
switch sessUpdate.ErrMsg {
|
if err == nil {
|
||||||
case "lch closed the connection", "rch closed the connection":
|
|
||||||
sessUpdate.ErrMsg = ""
|
sessUpdate.ErrMsg = ""
|
||||||
}
|
}
|
||||||
actx.db.Model(&sess).Updates(&sessUpdate)
|
actx.db.Model(&sess).Updates(&sessUpdate)
|
||||||
|
|
Loading…
Reference in a new issue