perf: improve getNetConnections function and Websocket handling (#11269)

* feat: Enhance WebSocket client functionality and improve data processing

- Reduced message queue size in WebSocket client from 100 to 32.
- Introduced atomic boolean to track client closure state.
- Added SendPayload method to handle message sending with queue management.
- Updated ProcessData function to utilize SendPayload for sending responses.
- Expanded netTypes to include both IPv4 and IPv6 protocols in network connection retrieval.
- Improved net connection processing by using a map for process names, enhancing efficiency.

* feat: Enhance WebSocket client and process data handling

- Added synchronization with sync.Once for safe closure of WebSocket client.
- Updated message queue size to a constant for better maintainability.
- Implemented context timeouts for process data retrieval to prevent blocking.
- Improved network connection handling by utilizing a more efficient method for retrieving connections.
- Introduced a new function to determine connection types based on protocol family.

* feat: Enhance network connection retrieval and process name mapping

- Updated getNetConnections function to improve efficiency by using maps for process names and connections.
- Introduced a new helper function to retrieve process names from the filesystem or process context.
- Enhanced filtering logic for network connections based on process ID, name, and port.
- Increased initial capacity for connection results to optimize performance.

* refactor: Rename SendPayload method to Send in WebSocket client

- Updated the SendPayload method to be more succinctly named Send for clarity.
- Ensured the method continues to handle message sending while maintaining existing functionality.

* refactor: Update ProcessData and getNetConnections for improved clarity and efficiency

- Replaced SendPayload method calls with Send for consistency in WebSocket message handling.
- Enhanced getNetConnections function by refining process name retrieval and filtering logic.
- Improved error handling in getProcessNameWithContext for better robustness.

* refactor: Simplify WebSocket client closure and reading logic

- Removed unnecessary synchronization for closing the WebSocket client.
- Updated the Read method to handle message reading directly without a separate Close method.
- Ensured the Socket is closed properly after reading messages to prevent resource leaks.
This commit is contained in:
KOMATA 2025-12-09 17:30:12 +08:00 committed by GitHub
parent d1c2a69820
commit 38985671c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 115 additions and 43 deletions

View file

@ -1,25 +1,31 @@
package websocket package websocket
import ( import (
"sync/atomic"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
const MaxMessageQuenue = 32
type Client struct { type Client struct {
ID string ID string
Socket *websocket.Conn Socket *websocket.Conn
Msg chan []byte Msg chan []byte
closed atomic.Bool
} }
func NewWsClient(ID string, socket *websocket.Conn) *Client { func NewWsClient(ID string, socket *websocket.Conn) *Client {
return &Client{ return &Client{
ID: ID, ID: ID,
Socket: socket, Socket: socket,
Msg: make(chan []byte, 100), Msg: make(chan []byte, MaxMessageQuenue),
} }
} }
func (c *Client) Read() { func (c *Client) Read() {
defer func() { defer func() {
c.closed.Store(true)
close(c.Msg) close(c.Msg)
}() }()
for { for {
@ -32,9 +38,7 @@ func (c *Client) Read() {
} }
func (c *Client) Write() { func (c *Client) Write() {
defer func() { defer c.Socket.Close()
c.Socket.Close()
}()
for { for {
message, ok := <-c.Msg message, ok := <-c.Msg
if !ok { if !ok {
@ -43,3 +47,13 @@ func (c *Client) Write() {
_ = c.Socket.WriteMessage(websocket.TextMessage, message) _ = c.Socket.WriteMessage(websocket.TextMessage, message)
} }
} }
func (c *Client) Send(res []byte) {
if c.closed.Load() {
return
}
select {
case c.Msg <- res:
default:
}
}

View file

@ -4,17 +4,20 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"strings" "strings"
"time" "time"
"github.com/1Panel-dev/1Panel/agent/utils/common"
"github.com/1Panel-dev/1Panel/agent/global" "github.com/1Panel-dev/1Panel/agent/global"
"github.com/1Panel-dev/1Panel/agent/utils/common"
"github.com/1Panel-dev/1Panel/agent/utils/files" "github.com/1Panel-dev/1Panel/agent/utils/files"
"github.com/shirou/gopsutil/v4/host" "github.com/shirou/gopsutil/v4/host"
"github.com/shirou/gopsutil/v4/net" "github.com/shirou/gopsutil/v4/net"
"github.com/shirou/gopsutil/v4/process" "github.com/shirou/gopsutil/v4/process"
) )
const defaultTimeout = 10 * time.Second
type WsInput struct { type WsInput struct {
Type string `json:"type"` Type string `json:"type"`
DownloadProgress DownloadProgress
@ -113,25 +116,25 @@ func ProcessData(c *Client, inputMsg []byte) {
if err != nil { if err != nil {
return return
} }
c.Msg <- res c.Send(res)
case "ps": case "ps":
res, err := getProcessData(wsInput.PsProcessConfig) res, err := getProcessData(wsInput.PsProcessConfig)
if err != nil { if err != nil {
return return
} }
c.Msg <- res c.Send(res)
case "ssh": case "ssh":
res, err := getSSHSessions(wsInput.SSHSessionConfig) res, err := getSSHSessions(wsInput.SSHSessionConfig)
if err != nil { if err != nil {
return return
} }
c.Msg <- res c.Send(res)
case "net": case "net":
res, err := getNetConnections(wsInput.NetConfig) res, err := getNetConnections(wsInput.NetConfig)
if err != nil { if err != nil {
return return
} }
c.Msg <- res c.Send(res)
} }
} }
@ -204,7 +207,8 @@ func handleProcessData(proc *process.Process, processConfig *PsProcessConfig, pi
} }
func getProcessData(processConfig PsProcessConfig) (res []byte, err error) { func getProcessData(processConfig PsProcessConfig) (res []byte, err error) {
ctx := context.Background() ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
processes, err := process.ProcessesWithContext(ctx) processes, err := process.ProcessesWithContext(ctx)
if err != nil { if err != nil {
@ -243,7 +247,10 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) {
users []host.UserStat users []host.UserStat
processes []*process.Process processes []*process.Process
) )
users, err = host.Users() ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
users, err = host.UsersWithContext(ctx)
if err != nil { if err != nil {
res, err = json.Marshal(result) res, err = json.Marshal(result)
return return
@ -268,8 +275,9 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) {
return return
} }
processes, err = process.Processes() processes, err = process.ProcessesWithContext(ctx)
if err != nil { if err != nil {
res, err = json.Marshal(result)
return return
} }
@ -312,42 +320,92 @@ func getSSHSessions(config SSHSessionConfig) (res []byte, err error) {
return return
} }
var netTypes = [...]string{"tcp", "udp"}
func getNetConnections(config NetConfig) (res []byte, err error) { func getNetConnections(config NetConfig) (res []byte, err error) {
var ( result := make([]ProcessConnect, 0, 1024)
result []ProcessConnect ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
proc *process.Process defer cancel()
)
for _, netType := range netTypes {
connections, _ := net.Connections(netType)
if err == nil {
for _, conn := range connections {
if config.ProcessID > 0 && config.ProcessID != conn.Pid {
continue
}
proc, err = process.NewProcess(conn.Pid)
if err == nil {
name, _ := proc.Name()
if name != "" && config.ProcessName != "" && !strings.Contains(name, config.ProcessName) {
continue
}
if config.Port > 0 && config.Port != conn.Laddr.Port && config.Port != conn.Raddr.Port {
continue
}
result = append(result, ProcessConnect{
Type: netType,
Status: conn.Status,
Laddr: conn.Laddr,
Raddr: conn.Raddr,
PID: conn.Pid,
Name: name,
})
}
connections, err := net.ConnectionsMaxWithContext(ctx, "all", 32768)
if err != nil {
res, _ = json.Marshal(result)
return
}
pidConnectionsMap := make(map[int32][]net.ConnectionStat, 256)
pidNameMap := make(map[int32]string, 256)
for _, conn := range connections {
if conn.Family != 2 && conn.Family != 10 {
continue
}
if conn.Pid == 0 {
continue
}
if config.ProcessID > 0 && conn.Pid != config.ProcessID {
continue
}
if config.Port > 0 && conn.Laddr.Port != config.Port && conn.Raddr.Port != config.Port {
continue
}
if _, exists := pidNameMap[conn.Pid]; !exists {
pName, _ := getProcessNameWithContext(ctx, conn.Pid)
if pName == "" {
pName = "<UNKNOWN>"
} }
pidNameMap[conn.Pid] = pName
}
pidConnectionsMap[conn.Pid] = append(pidConnectionsMap[conn.Pid], conn)
}
for pid, connections := range pidConnectionsMap {
pName := pidNameMap[pid]
if config.ProcessName != "" && !strings.Contains(pName, config.ProcessName) {
continue
}
for _, conn := range connections {
result = append(result, ProcessConnect{
Type: getConnectionType(conn.Type, conn.Family),
Status: conn.Status,
Laddr: conn.Laddr,
Raddr: conn.Raddr,
PID: conn.Pid,
Name: pName,
})
} }
} }
res, err = json.Marshal(result) res, err = json.Marshal(result)
return return
} }
func getProcessNameWithContext(ctx context.Context, pid int32) (string, error) {
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/comm", pid))
if err == nil && len(data) > 0 {
return strings.TrimSpace(string(data)), nil
}
p, err := process.NewProcessWithContext(ctx, pid)
if err != nil {
return "", err
}
return p.Name()
}
func getConnectionType(connType uint32, family uint32) string {
switch {
case connType == 1 && family == 2:
return "tcp"
case connType == 1 && family == 10:
return "tcp6"
case connType == 2 && family == 2:
return "udp"
case connType == 2 && family == 10:
return "udp6"
default:
return "unknown"
}
}