diff --git a/core/utils/ssh/ssh.go b/core/utils/ssh/ssh.go index 9471d8250..26d5278b3 100644 --- a/core/utils/ssh/ssh.go +++ b/core/utils/ssh/ssh.go @@ -2,6 +2,7 @@ package ssh import ( "fmt" + "net" "strings" "time" @@ -51,7 +52,7 @@ func NewClient(c ConnInfo) (*SSHClient, error) { if strings.Contains(c.Addr, ":") { proto = "tcp6" } - client, err := gossh.Dial(proto, addr, config) + client, err := DialWithTimeout(proto, addr, config) if nil != err { return nil, err } @@ -207,3 +208,20 @@ func (c *SSHClient) RunWithStreamOutput(command string, outputCallback func(stri return err } + +func DialWithTimeout(network, addr string, config *gossh.ClientConfig) (*gossh.Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + _ = conn.SetDeadline(time.Now().Add(config.Timeout)) + c, chans, reqs, err := gossh.NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + if err := conn.SetDeadline(time.Time{}); err != nil { + conn.Close() + return nil, fmt.Errorf("clear deadline failed: %v", err) + } + return gossh.NewClient(c, chans, reqs), nil +}