package client import ( "fmt" "strconv" "strings" "github.com/1Panel-dev/1Panel/agent/buserr" "github.com/1Panel-dev/1Panel/agent/global" "github.com/1Panel-dev/1Panel/agent/utils/cmd" "github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables" ) type Iptables struct{} func NewIptables() (*Iptables, error) { return &Iptables{}, nil } func (i *Iptables) Name() string { return "iptables" } func (i *Iptables) Status() (bool, error) { stdout, err := cmd.RunDefaultWithStdoutBashC("iptables -L -n | head -1") if err != nil { return false, err } return strings.Contains(stdout, "Chain"), nil } func (i *Iptables) Start() error { return nil } func (i *Iptables) Stop() error { return nil } func (i *Iptables) Restart() error { return nil } func (i *Iptables) Reload() error { return nil } func (i *Iptables) Version() (string, error) { stdout, err := cmd.RunDefaultWithStdoutBashC("iptables --version") if err != nil { return "", fmt.Errorf("failed to get iptables version: %w", err) } parts := strings.Fields(stdout) if len(parts) >= 2 { return strings.TrimPrefix(parts[1], "v"), nil } return strings.TrimSpace(stdout), nil } func (i *Iptables) ListPort() ([]FireInfo, error) { var datas []FireInfo basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic) if err != nil { return nil, err } beforeRules, _ := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasicBefore) basicRules = append(basicRules, beforeRules...) for _, item := range basicRules { if len(item.DstPort) == 0 { continue } if item.Strategy == "drop" || item.Strategy == "reject" { item.Strategy = "drop" } datas = append(datas, FireInfo{ Chain: item.Chain, Address: item.SrcIP, Protocol: item.Protocol, Port: item.DstPort, Strategy: item.Strategy, Family: "ipv4", }) } return datas, nil } func (i *Iptables) ListAddress() ([]FireInfo, error) { var datas []FireInfo basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic) if err != nil { return nil, err } for _, item := range basicRules { if len(item.DstPort) != 0 || len(item.SrcPort) != 0 { continue } if item.Strategy == "drop" || item.Strategy == "reject" { item.Strategy = "drop" } datas = append(datas, FireInfo{ Address: item.SrcIP, Strategy: item.Strategy, Family: "ipv4", }) } return datas, nil } func (i *Iptables) Port(port FireInfo, operation string) error { if operation != "add" && operation != "remove" { return buserr.New("ErrCmdIllegal") } if len(port.Chain) == 0 { port.Chain = iptables.Chain1PanelBasic } portSpec, err := normalizePortSpec(port.Port) if err != nil { return err } protocol := port.Protocol if protocol == "" { protocol = "tcp" } action := "ACCEPT" if port.Strategy == "drop" { action = "DROP" } ruleArgs := []string{fmt.Sprintf("-p %s", protocol)} ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSpec), fmt.Sprintf("-j %s", action)) ruleSpec := strings.Join(ruleArgs, " ") if operation == "add" { if err := iptables.AddRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil { return err } } else { if err := iptables.DeleteRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil { return err } } name := iptables.BasicFileName if port.Chain == iptables.Chain1PanelBasicBefore { name = iptables.BasicBeforeFileName } if port.Chain == iptables.Chain1PanelBasic { if err := iptables.SaveRulesToFile(iptables.FilterTab, port.Chain, name); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err) } } return nil } func (i *Iptables) RichRules(rule FireInfo, operation string) error { if operation != "add" && operation != "remove" { return buserr.New("ErrCmdIllegal") } if len(rule.Chain) == 0 { rule.Chain = iptables.Chain1PanelBasic } address := strings.TrimSpace(rule.Address) if strings.EqualFold(address, "Anywhere") { address = "" } action := "ACCEPT" if rule.Strategy == "drop" { action = "DROP" } var ruleArgs []string if address != "" { ruleArgs = append(ruleArgs, fmt.Sprintf("-s %s", address)) } protocol := strings.TrimSpace(rule.Protocol) if rule.Port != "" && protocol == "" { protocol = "tcp" } if protocol != "" { ruleArgs = append(ruleArgs, fmt.Sprintf("-p %s", protocol)) } if rule.Port != "" { portSegment, err := normalizePortSpec(rule.Port) if err != nil { return err } if protocol == "" { return fmt.Errorf("protocol is required when specifying a port") } ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSegment)) } ruleArgs = append(ruleArgs, fmt.Sprintf("-j %s", action)) ruleSpec := strings.Join(ruleArgs, " ") if operation == "add" { if err := iptables.AddRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil { return err } } else { if err := iptables.DeleteRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil { return err } } name := iptables.BasicFileName if rule.Chain == iptables.Chain1PanelBasicBefore { name = iptables.BasicBeforeFileName } if rule.Chain == iptables.Chain1PanelBasic { if err := iptables.SaveRulesToFile(iptables.FilterTab, rule.Chain, name); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err) } } return nil } func (i *Iptables) PortForward(info Forward, operation string) error { return iptablesPortForward(info, operation) } func (i *Iptables) EnableForward() error { return EnableIptablesForward() } func (i *Iptables) ListForward() ([]FireInfo, error) { return iptablesListForward() } func EnableIptablesForward() error { if err := cmd.RunDefaultBashC("echo 1 > /proc/sys/net/ipv4/ip_forward"); err != nil { return fmt.Errorf("failed to enable IP forwarding: %w", err) } _ = cmd.RunDefaultBashC("grep -q '^net.ipv4.ip_forward' /etc/sysctl.conf || echo 'net.ipv4.ip_forward = 1' >> /etc/sysctl.conf") _ = cmd.RunDefaultBashC("sysctl -p") if err := iptables.AddChainWithAppend(iptables.NatTab, "PREROUTING", iptables.Chain1PanelPreRouting); err != nil { return err } if err := iptables.AddChainWithAppend(iptables.NatTab, "POSTROUTING", iptables.Chain1PanelPostRouting); err != nil { return err } if err := iptables.AddChainWithAppend(iptables.FilterTab, "FORWARD", iptables.Chain1PanelForward); err != nil { return err } return nil } func iptablesPortForward(info Forward, operation string) error { if operation != "add" && operation != "remove" { return buserr.New("ErrCmdIllegal") } if info.Protocol == "" || info.Port == "" || info.TargetPort == "" { return fmt.Errorf("protocol, port, and target port are required") } if operation == "add" { if err := iptables.AddForward(info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface, true); err != nil { return err } } else { if err := iptables.DeleteForward(info.Num, info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface); err != nil { return err } } forwardPersistence() return nil } func forwardPersistence() { if err := iptables.SaveRulesToFile(iptables.FilterTab, iptables.Chain1PanelForward, iptables.ForwardFileName); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelForward, err) } if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPreRouting, iptables.ForwardFileName1); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPreRouting, err) } if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPostRouting, iptables.ForwardFileName2); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPostRouting, err) } } func iptablesListForward() ([]FireInfo, error) { natList, err := iptables.ListForward(iptables.Chain1PanelPreRouting) if err != nil { return nil, fmt.Errorf("failed to list NAT rules: %w", err) } var datas []FireInfo for _, nat := range natList { datas = append(datas, FireInfo{ Num: nat.Num, Protocol: nat.Protocol, Port: strings.TrimPrefix(nat.SrcPort, ":"), TargetIP: nat.Destination, TargetPort: strings.TrimPrefix(nat.DestPort, ":"), Interface: nat.InIface, }) } return datas, nil } func parsePort(portStr string) (int, error) { port, err := strconv.Atoi(portStr) if err != nil { return 0, fmt.Errorf("invalid port number: %s", portStr) } if port < 1 || port > 65535 { return 0, fmt.Errorf("port out of range: %d", port) } return port, nil } func normalizePortSpec(port string) (string, error) { value := strings.TrimSpace(port) if value == "" { return "", fmt.Errorf("port is required") } separator := "" if strings.Contains(value, "-") { separator = "-" } else if strings.Contains(value, ":") { separator = ":" } if separator != "" { parts := strings.Split(value, separator) if len(parts) != 2 { return "", fmt.Errorf("invalid port range: %s", port) } start, err := parsePort(strings.TrimSpace(parts[0])) if err != nil { return "", err } end, err := parsePort(strings.TrimSpace(parts[1])) if err != nil { return "", err } if start > end { return "", fmt.Errorf("invalid port range: %d-%d", start, end) } return fmt.Sprintf("%d:%d", start, end), nil } single, err := parsePort(value) if err != nil { return "", err } return fmt.Sprintf("%d", single), nil }