diff --git a/agent/utils/websocket/client.go b/agent/utils/websocket/client.go index 368eeeee9..e9fdf9507 100644 --- a/agent/utils/websocket/client.go +++ b/agent/utils/websocket/client.go @@ -1,25 +1,31 @@ package websocket import ( + "sync/atomic" + "github.com/gorilla/websocket" ) +const MaxMessageQuenue = 32 + type Client struct { ID string Socket *websocket.Conn Msg chan []byte + closed atomic.Bool } func NewWsClient(ID string, socket *websocket.Conn) *Client { return &Client{ ID: ID, Socket: socket, - Msg: make(chan []byte, 100), + Msg: make(chan []byte, MaxMessageQuenue), } } func (c *Client) Read() { defer func() { + c.closed.Store(true) close(c.Msg) }() for { @@ -32,9 +38,7 @@ func (c *Client) Read() { } func (c *Client) Write() { - defer func() { - c.Socket.Close() - }() + defer c.Socket.Close() for { message, ok := <-c.Msg if !ok { @@ -43,3 +47,13 @@ func (c *Client) Write() { _ = c.Socket.WriteMessage(websocket.TextMessage, message) } } + +func (c *Client) Send(res []byte) { + if c.closed.Load() { + return + } + select { + case c.Msg <- res: + default: + } +} diff --git a/agent/utils/websocket/process_data.go b/agent/utils/websocket/process_data.go index 7ac0b177a..eddf5518e 100644 --- a/agent/utils/websocket/process_data.go +++ b/agent/utils/websocket/process_data.go @@ -4,17 +4,20 @@ import ( "context" "encoding/json" "fmt" + "os" "strings" "time" - "github.com/1Panel-dev/1Panel/agent/utils/common" "github.com/1Panel-dev/1Panel/agent/global" + "github.com/1Panel-dev/1Panel/agent/utils/common" "github.com/1Panel-dev/1Panel/agent/utils/files" "github.com/shirou/gopsutil/v4/host" "github.com/shirou/gopsutil/v4/net" "github.com/shirou/gopsutil/v4/process" ) +const defaultTimeout = 10 * time.Second + type WsInput struct { Type string `json:"type"` DownloadProgress @@ -113,25 +116,25 @@ func ProcessData(c *Client, inputMsg []byte) { if err != nil { return } - c.Msg <- res + c.Send(res) case "ps": res, err := getProcessData(wsInput.PsProcessConfig) if err != nil { return } - c.Msg <- res + c.Send(res) case "ssh": res, err := getSSHSessions(wsInput.SSHSessionConfig) if err != nil { return } - c.Msg <- res + c.Send(res) case "net": res, err := getNetConnections(wsInput.NetConfig) if err != nil { return } - c.Msg <- res + c.Send(res) } } @@ -204,7 +207,8 @@ func handleProcessData(proc *process.Process, processConfig *PsProcessConfig, pi } func getProcessData(processConfig PsProcessConfig) (res []byte, err error) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() processes, err := process.ProcessesWithContext(ctx) if err != nil { @@ -243,7 +247,10 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) { users []host.UserStat processes []*process.Process ) - users, err = host.Users() + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + users, err = host.UsersWithContext(ctx) if err != nil { res, err = json.Marshal(result) return @@ -268,8 +275,9 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) { return } - processes, err = process.Processes() + processes, err = process.ProcessesWithContext(ctx) if err != nil { + res, err = json.Marshal(result) return } @@ -312,42 +320,92 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) { return } -var netTypes = [...]string{"tcp", "udp"} - func getNetConnections(config NetConfig) (res []byte, err error) { - var ( - result []ProcessConnect - proc *process.Process - ) - for _, netType := range netTypes { - connections, _ := net.Connections(netType) - if err == nil { - for _, conn := range connections { - if config.ProcessID > 0 && config.ProcessID != conn.Pid { - continue - } - proc, err = process.NewProcess(conn.Pid) - if err == nil { - name, _ := proc.Name() - if name != "" && config.ProcessName != "" && !strings.Contains(name, config.ProcessName) { - continue - } - if config.Port > 0 && config.Port != conn.Laddr.Port && config.Port != conn.Raddr.Port { - continue - } - result = append(result, ProcessConnect{ - Type: netType, - Status: conn.Status, - Laddr: conn.Laddr, - Raddr: conn.Raddr, - PID: conn.Pid, - Name: name, - }) - } + result := make([]ProcessConnect, 0, 1024) + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + connections, err := net.ConnectionsMaxWithContext(ctx, "all", 32768) + if err != nil { + res, _ = json.Marshal(result) + return + } + + pidConnectionsMap := make(map[int32][]net.ConnectionStat, 256) + pidNameMap := make(map[int32]string, 256) + + for _, conn := range connections { + if conn.Family != 2 && conn.Family != 10 { + continue + } + + if conn.Pid == 0 { + continue + } + + if config.ProcessID > 0 && conn.Pid != config.ProcessID { + continue + } + + if config.Port > 0 && conn.Laddr.Port != config.Port && conn.Raddr.Port != config.Port { + continue + } + + if _, exists := pidNameMap[conn.Pid]; !exists { + pName, _ := getProcessNameWithContext(ctx, conn.Pid) + if pName == "" { + pName = "" } + pidNameMap[conn.Pid] = pName + } + + pidConnectionsMap[conn.Pid] = append(pidConnectionsMap[conn.Pid], conn) + } + + for pid, connections := range pidConnectionsMap { + pName := pidNameMap[pid] + if config.ProcessName != "" && !strings.Contains(pName, config.ProcessName) { + continue + } + for _, conn := range connections { + result = append(result, ProcessConnect{ + Type: getConnectionType(conn.Type, conn.Family), + Status: conn.Status, + Laddr: conn.Laddr, + Raddr: conn.Raddr, + PID: conn.Pid, + Name: pName, + }) } } + res, err = json.Marshal(result) return } + +func getProcessNameWithContext(ctx context.Context, pid int32) (string, error) { + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid)) + if err == nil && len(data) > 0 { + return strings.TrimSpace(string(data)), nil + } + p, err := process.NewProcessWithContext(ctx, pid) + if err != nil { + return "", err + } + return p.Name() +} + +func getConnectionType(connType uint32, family uint32) string { + switch { + case connType == 1 && family == 2: + return "tcp" + case connType == 1 && family == 10: + return "tcp6" + case connType == 2 && family == 2: + return "udp" + case connType == 2 && family == 10: + return "udp6" + default: + return "unknown" + } +}