package client

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/1Panel-dev/1Panel/agent/buserr"
	"github.com/1Panel-dev/1Panel/agent/global"
	"github.com/1Panel-dev/1Panel/agent/utils/cmd"
	"github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables"
	"github.com/1Panel-dev/1Panel/agent/utils/re"
)

// Iptables is the firewall client that drives the iptables command line tool.
type Iptables struct{}

// NewIptables returns a new Iptables client. The returned error is always nil.
func NewIptables() (*Iptables, error) {
	return &Iptables{}, nil
}

// Name returns the backend identifier "iptables".
func (i *Iptables) Name() string {
	return "iptables"
}

// Status runs `iptables -L -n | head -1` and reports whether the output
// contains the word "Chain". A command failure is returned as the error.
func (i *Iptables) Status() (bool, error) {
	stdout, err := cmd.RunDefaultWithStdoutBashC("iptables -L -n | head -1")
	if err != nil {
		return false, err
	}
	return strings.Contains(stdout, "Chain"), nil
}

// Start is a no-op for this backend and always returns nil.
func (i *Iptables) Start() error {
	return nil
}

// Stop is a no-op for this backend and always returns nil.
func (i *Iptables) Stop() error {
	return nil
}

// Restart is a no-op for this backend and always returns nil.
func (i *Iptables) Restart() error {
	return nil
}

// Reload is a no-op for this backend and always returns nil.
func (i *Iptables) Reload() error {
	return nil
}

// Version returns the second whitespace-separated field of
// `iptables --version` with a leading "v" removed. When the output has fewer
// than two fields, the trimmed output is returned unchanged.
func (i *Iptables) Version() (string, error) {
	stdout, err := cmd.RunDefaultWithStdoutBashC("iptables --version")
	if err != nil {
		return "", fmt.Errorf("failed to get iptables version: %w", err)
	}

	parts := strings.Fields(stdout)
	if len(parts) >= 2 {
		return strings.TrimPrefix(parts[1], "v"), nil
	}
	return strings.TrimSpace(stdout), nil
}

// ListPort lists the port rules appended to the 1Panel basic chain of the
// filter table.
//
// Only "-A <chain>" lines that match the port pattern with exactly four
// capture groups (address, protocol, port, action) are reported. Port ranges
// are reported with "-" in place of ":". DROP and REJECT are both reported as
// strategy "drop"; every other action is reported as "accept". Family is
// always "ipv4".
func (i *Iptables) ListPort() ([]FireInfo, error) {
	stdout, err := iptables.RunWithStd(iptables.FilterTab, fmt.Sprintf("-S %s", iptables.Chain1PanelBasic))
	if err != nil {
		return nil, fmt.Errorf("failed to list 1PANEL_BASIC rules: %w", err)
	}

	var datas []FireInfo
	// The prefix does not depend on the line, so build it once.
	rulePrefix := fmt.Sprintf("-A %s", iptables.Chain1PanelBasic)
	chainPortRegex := re.GetRegex(iptables.Chian1PanelBasicPortPattern)
	for _, line := range strings.Split(stdout, "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, rulePrefix) {
			continue
		}
		matches := chainPortRegex.FindStringSubmatch(line)
		if len(matches) != 5 {
			continue
		}

		strategy := "accept"
		if action := strings.ToLower(matches[4]); action == "drop" || action == "reject" {
			strategy = "drop"
		}
		datas = append(datas, FireInfo{
			Address:  matches[1],
			Protocol: matches[2],
			Port:     strings.ReplaceAll(matches[3], ":", "-"),
			Strategy: strategy,
			Family:   "ipv4",
		})
	}
	return datas, nil
}

// ListAddress lists the address rules appended to the 1Panel basic chain of
// the filter table.
//
// The address is taken from the first capture group of the address pattern,
// falling back to the second; lines where both are empty are skipped. Each
// address is reported once, using the first rule seen for it, and results are
// returned in the order the rules appear in the iptables output. DROP and
// REJECT are both reported as strategy "drop"; every other action is reported
// as "accept". Family is always "ipv4".
func (i *Iptables) ListAddress() ([]FireInfo, error) {
	stdout, err := iptables.RunWithStd(iptables.FilterTab, fmt.Sprintf("-S %s", iptables.Chain1PanelBasic))
	if err != nil {
		return nil, fmt.Errorf("failed to list 1PANEL_BASIC address rules: %w", err)
	}

	var datas []FireInfo
	// seen de-duplicates addresses while datas keeps rule order, so repeated
	// calls return a stable ordering instead of random map iteration order.
	seen := make(map[string]struct{})
	rulePrefix := fmt.Sprintf("-A %s", iptables.Chain1PanelBasic)
	chainAddressRegex := re.GetRegex(iptables.Chain1PanelBasicAddressPattern)
	for _, line := range strings.Split(stdout, "\n") {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, rulePrefix) {
			continue
		}
		matches := chainAddressRegex.FindStringSubmatch(line)
		if len(matches) < 4 {
			continue
		}

		address := matches[1]
		if address == "" {
			address = matches[2]
		}
		if address == "" {
			continue
		}
		if _, exists := seen[address]; exists {
			continue
		}
		seen[address] = struct{}{}

		strategy := "accept"
		if action := strings.ToLower(matches[3]); action == "drop" || action == "reject" {
			strategy = "drop"
		}
		datas = append(datas, FireInfo{
			Address:  address,
			Strategy: strategy,
			Family:   "ipv4",
		})
	}
	return datas, nil
}

// Port adds or removes a port rule in the 1Panel basic chain.
//
// operation must be "add" or "remove"; anything else yields the
// "ErrCmdIllegal" business error. port.Port is validated and normalized by
// normalizePortSpec. An empty protocol defaults to "tcp". Strategy "drop"
// maps to the DROP target and every other value to ACCEPT. A failure to
// persist the chain afterwards is logged and not returned.
func (i *Iptables) Port(port FireInfo, operation string) error {
	if operation != "add" && operation != "remove" {
		return buserr.New("ErrCmdIllegal")
	}

	portSpec, err := normalizePortSpec(port.Port)
	if err != nil {
		return err
	}

	protocol := port.Protocol
	if protocol == "" {
		protocol = "tcp"
	}
	action := "ACCEPT"
	if port.Strategy == "drop" {
		action = "DROP"
	}

	ruleArgs := []string{fmt.Sprintf("-p %s", protocol)}
	if protocol == "tcp" || protocol == "udp" {
		ruleArgs = append(ruleArgs, fmt.Sprintf("-m %s", protocol))
	}
	ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSpec), fmt.Sprintf("-j %s", action))

	return iptablesApplyBasicRule(strings.Join(ruleArgs, " "), operation)
}

// RichRules adds or removes a rule in the 1Panel basic chain that may combine
// a source address, a protocol and a destination port.
//
// operation must be "add" or "remove"; anything else yields the
// "ErrCmdIllegal" business error. An address equal to "Anywhere"
// (case-insensitive) is treated as no address. When a port is given without a
// protocol, the protocol defaults to "tcp". Strategy "drop" maps to the DROP
// target and every other value to ACCEPT. A failure to persist the chain
// afterwards is logged and not returned.
func (i *Iptables) RichRules(rule FireInfo, operation string) error {
	if operation != "add" && operation != "remove" {
		return buserr.New("ErrCmdIllegal")
	}

	address := strings.TrimSpace(rule.Address)
	if strings.EqualFold(address, "Anywhere") {
		address = ""
	}
	action := "ACCEPT"
	if rule.Strategy == "drop" {
		action = "DROP"
	}

	var ruleArgs []string
	if address != "" {
		ruleArgs = append(ruleArgs, fmt.Sprintf("-s %s", address))
	}

	protocol := strings.TrimSpace(rule.Protocol)
	if rule.Port != "" && protocol == "" {
		protocol = "tcp"
	}
	if protocol != "" {
		ruleArgs = append(ruleArgs, fmt.Sprintf("-p %s", protocol))
	}

	if rule.Port != "" {
		// protocol is guaranteed non-empty here because of the default above.
		portSegment, err := normalizePortSpec(rule.Port)
		if err != nil {
			return err
		}
		if protocol == "tcp" || protocol == "udp" {
			ruleArgs = append(ruleArgs, fmt.Sprintf("-m %s", protocol))
		}
		ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSegment))
	}
	ruleArgs = append(ruleArgs, fmt.Sprintf("-j %s", action))

	return iptablesApplyBasicRule(strings.Join(ruleArgs, " "), operation)
}

// iptablesApplyBasicRule adds ruleSpec to the 1Panel basic chain when
// operation is "add" and deletes it otherwise, then saves the chain to its
// rules file. A save failure is logged and not returned.
func iptablesApplyBasicRule(ruleSpec, operation string) error {
	if operation == "add" {
		if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasic, ruleSpec); err != nil {
			return err
		}
	} else {
		if err := iptables.DeleteRule(iptables.FilterTab, iptables.Chain1PanelBasic, ruleSpec); err != nil {
			return err
		}
	}

	if err := iptables.SaveRulesToFile(iptables.FilterTab, iptables.Chain1PanelBasic, iptables.BasicFileName); err != nil {
		global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
	}
	return nil
}

// PortForward adds or removes a port forwarding rule; see iptablesPortForward.
func (i *Iptables) PortForward(info Forward, operation string) error {
	return iptablesPortForward(info, operation)
}

// EnableForward enables IP forwarding; see EnableIptablesForward.
func (i *Iptables) EnableForward() error {
	return EnableIptablesForward()
}

// ListForward lists the configured port forwarding rules; see
// iptablesListForward.
func (i *Iptables) ListForward() ([]FireInfo, error) {
	return iptablesListForward()
}

// EnableIptablesForward turns on IPv4 forwarding and hooks the 1Panel chains
// into the PREROUTING and POSTROUTING chains of the nat table and the FORWARD
// chain of the filter table.
//
// Writing /proc/sys/net/ipv4/ip_forward must succeed. Updating
// /etc/sysctl.conf and running `sysctl -p` are best-effort and their errors
// are ignored.
func EnableIptablesForward() error {
	if err := cmd.RunDefaultBashC("echo 1 > /proc/sys/net/ipv4/ip_forward"); err != nil {
		return fmt.Errorf("failed to enable IP forwarding: %w", err)
	}

	// Best-effort: forwarding is already active at this point, so errors from
	// these two commands are deliberately not treated as fatal.
	_ = cmd.RunDefaultBashC("grep -q '^net.ipv4.ip_forward' /etc/sysctl.conf || echo 'net.ipv4.ip_forward = 1' >> /etc/sysctl.conf")
	_ = cmd.RunDefaultBashC("sysctl -p")

	if err := iptables.AddChainWithAppend(iptables.NatTab, "PREROUTING", iptables.Chain1PanelPreRouting); err != nil {
		return err
	}
	if err := iptables.AddChainWithAppend(iptables.NatTab, "POSTROUTING", iptables.Chain1PanelPostRouting); err != nil {
		return err
	}
	if err := iptables.AddChainWithAppend(iptables.FilterTab, "FORWARD", iptables.Chain1PanelForward); err != nil {
		return err
	}
	return nil
}

// iptablesPortForward adds or removes a port forwarding rule.
//
// operation must be "add" or "remove"; anything else yields the
// "ErrCmdIllegal" business error. Protocol, Port and TargetPort are required.
//
// For "add" the rule is created, the forward chains are persisted and nil is
// returned. For "remove" the first rule in the 1Panel PREROUTING chain whose
// protocol, source port and destination port match is deleted (an empty
// TargetIP is passed as "127.0.0.1"), the chains are persisted and nil is
// returned. If no rule matches, "forward rule not found" is returned.
func iptablesPortForward(info Forward, operation string) error {
	if operation != "add" && operation != "remove" {
		return buserr.New("ErrCmdIllegal")
	}
	if info.Protocol == "" || info.Port == "" || info.TargetPort == "" {
		return fmt.Errorf("protocol, port, and target port are required")
	}

	if operation == "add" {
		if err := iptables.AddForward(info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface, true); err != nil {
			return err
		}
		forwardPersistence()
		// Return here: falling through would delete the rule just added.
		return nil
	}

	natList, err := iptables.ListForward(iptables.Chain1PanelPreRouting)
	if err != nil {
		return fmt.Errorf("failed to list NAT rules: %w", err)
	}

	for _, nat := range natList {
		if nat.Protocol != info.Protocol ||
			strings.TrimPrefix(nat.SrcPort, ":") != info.Port ||
			strings.TrimPrefix(nat.DestPort, ":") != info.TargetPort {
			continue
		}

		targetIP := info.TargetIP
		if targetIP == "" {
			targetIP = "127.0.0.1"
		}
		if err := iptables.DeleteForward(nat.Num, info.Protocol, info.Port, targetIP, info.TargetPort, info.Interface); err != nil {
			return err
		}
		forwardPersistence()
		// Stop after the first deletion: the remaining entries were listed
		// before the delete, so their rule numbers may no longer be valid.
		return nil
	}
	return fmt.Errorf("forward rule not found")
}

// forwardPersistence saves the 1Panel FORWARD, PREROUTING and POSTROUTING
// chains to their rules files. Failures are logged and otherwise ignored.
func forwardPersistence() {
	if err := iptables.SaveRulesToFile(iptables.FilterTab, iptables.Chain1PanelForward, iptables.ForwardFileName); err != nil {
		global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelForward, err)
	}
	if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPreRouting, iptables.ForwardFileName1); err != nil {
		global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPreRouting, err)
	}
	if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPostRouting, iptables.ForwardFileName2); err != nil {
		global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPostRouting, err)
	}
}

// iptablesListForward converts the rules of the 1Panel PREROUTING chain into
// FireInfo entries. A leading ":" is stripped from the source and destination
// ports.
func iptablesListForward() ([]FireInfo, error) {
	natList, err := iptables.ListForward(iptables.Chain1PanelPreRouting)
	if err != nil {
		return nil, fmt.Errorf("failed to list NAT rules: %w", err)
	}

	var datas []FireInfo
	for _, nat := range natList {
		datas = append(datas, FireInfo{
			Num:        nat.Num,
			Protocol:   nat.Protocol,
			Port:       strings.TrimPrefix(nat.SrcPort, ":"),
			TargetIP:   nat.Destination,
			TargetPort: strings.TrimPrefix(nat.DestPort, ":"),
			Interface:  nat.InIface,
		})
	}
	return datas, nil
}

// parsePort parses portStr as a decimal port number and returns an error if
// it is not an integer or lies outside 1-65535.
func parsePort(portStr string) (int, error) {
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return 0, fmt.Errorf("invalid port number: %s", portStr)
	}
	if port < 1 || port > 65535 {
		return 0, fmt.Errorf("port out of range: %d", port)
	}
	return port, nil
}

// normalizePortSpec validates a single port or a port range and returns it in
// the form used for --dport.
//
// Surrounding whitespace is ignored. A range may be written "start-end" or
// "start:end" and is returned as "start:end"; it must have exactly two parts
// with start <= end. Each port must satisfy parsePort. An empty value yields
// "port is required".
func normalizePortSpec(port string) (string, error) {
	value := strings.TrimSpace(port)
	if value == "" {
		return "", fmt.Errorf("port is required")
	}

	// "-" takes precedence over ":" when both are present.
	separator := ""
	switch {
	case strings.Contains(value, "-"):
		separator = "-"
	case strings.Contains(value, ":"):
		separator = ":"
	}

	if separator == "" {
		single, err := parsePort(value)
		if err != nil {
			return "", err
		}
		return strconv.Itoa(single), nil
	}

	parts := strings.Split(value, separator)
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid port range: %s", port)
	}
	start, err := parsePort(strings.TrimSpace(parts[0]))
	if err != nil {
		return "", err
	}
	end, err := parsePort(strings.TrimSpace(parts[1]))
	if err != nil {
		return "", err
	}
	if start > end {
		return "", fmt.Errorf("invalid port range: %d-%d", start, end)
	}
	return fmt.Sprintf("%d:%d", start, end), nil
}