1Panel/agent/utils/firewall/client/iptables.go

359 lines
9.3 KiB
Go

package client
import (
"fmt"
"strconv"
"strings"
"github.com/1Panel-dev/1Panel/agent/buserr"
"github.com/1Panel-dev/1Panel/agent/global"
"github.com/1Panel-dev/1Panel/agent/utils/cmd"
"github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables"
)
type Iptables struct{}
func NewIptables() (*Iptables, error) {
return &Iptables{}, nil
}
func (i *Iptables) Name() string {
return "iptables"
}
func (i *Iptables) Status() (bool, error) {
stdout, err := cmd.RunDefaultWithStdoutBashC("iptables -L -n | head -1")
if err != nil {
return false, err
}
return strings.Contains(stdout, "Chain"), nil
}
func (i *Iptables) Start() error {
return nil
}
func (i *Iptables) Stop() error {
return nil
}
func (i *Iptables) Restart() error {
return nil
}
func (i *Iptables) Reload() error {
return nil
}
func (i *Iptables) Version() (string, error) {
stdout, err := cmd.RunDefaultWithStdoutBashC("iptables --version")
if err != nil {
return "", fmt.Errorf("failed to get iptables version: %w", err)
}
parts := strings.Fields(stdout)
if len(parts) >= 2 {
return strings.TrimPrefix(parts[1], "v"), nil
}
return strings.TrimSpace(stdout), nil
}
func (i *Iptables) ListPort() ([]FireInfo, error) {
var datas []FireInfo
basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic)
if err != nil {
return nil, err
}
beforeRules, _ := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasicBefore)
basicRules = append(basicRules, beforeRules...)
for _, item := range basicRules {
if len(item.DstPort) == 0 {
continue
}
if item.Strategy == "drop" || item.Strategy == "reject" {
item.Strategy = "drop"
}
datas = append(datas, FireInfo{
Chain: item.Chain,
Address: item.SrcIP,
Protocol: item.Protocol,
Port: item.DstPort,
Strategy: item.Strategy,
Family: "ipv4",
})
}
return datas, nil
}
func (i *Iptables) ListAddress() ([]FireInfo, error) {
var datas []FireInfo
basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic)
if err != nil {
return nil, err
}
for _, item := range basicRules {
if len(item.DstPort) != 0 || len(item.SrcPort) != 0 {
continue
}
if item.Strategy == "drop" || item.Strategy == "reject" {
item.Strategy = "drop"
}
datas = append(datas, FireInfo{
Address: item.SrcIP,
Strategy: item.Strategy,
Family: "ipv4",
})
}
return datas, nil
}
func (i *Iptables) Port(port FireInfo, operation string) error {
if operation != "add" && operation != "remove" {
return buserr.New("ErrCmdIllegal")
}
if len(port.Chain) == 0 {
port.Chain = iptables.Chain1PanelBasic
}
portSpec, err := normalizePortSpec(port.Port)
if err != nil {
return err
}
protocol := port.Protocol
if protocol == "" {
protocol = "tcp"
}
action := "ACCEPT"
if port.Strategy == "drop" {
action = "DROP"
}
ruleArgs := []string{fmt.Sprintf("-p %s", protocol)}
ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSpec), fmt.Sprintf("-j %s", action))
ruleSpec := strings.Join(ruleArgs, " ")
if operation == "add" {
if err := iptables.AddRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil {
return err
}
} else {
if err := iptables.DeleteRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil {
return err
}
}
name := iptables.BasicFileName
if port.Chain == iptables.Chain1PanelBasicBefore {
name = iptables.BasicBeforeFileName
}
if port.Chain == iptables.Chain1PanelBasic {
if err := iptables.SaveRulesToFile(iptables.FilterTab, port.Chain, name); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
}
}
return nil
}
func (i *Iptables) RichRules(rule FireInfo, operation string) error {
if operation != "add" && operation != "remove" {
return buserr.New("ErrCmdIllegal")
}
if len(rule.Chain) == 0 {
rule.Chain = iptables.Chain1PanelBasic
}
address := strings.TrimSpace(rule.Address)
if strings.EqualFold(address, "Anywhere") {
address = ""
}
action := "ACCEPT"
if rule.Strategy == "drop" {
action = "DROP"
}
var ruleArgs []string
if address != "" {
ruleArgs = append(ruleArgs, fmt.Sprintf("-s %s", address))
}
protocol := strings.TrimSpace(rule.Protocol)
if rule.Port != "" && protocol == "" {
protocol = "tcp"
}
if protocol != "" {
ruleArgs = append(ruleArgs, fmt.Sprintf("-p %s", protocol))
}
if rule.Port != "" {
portSegment, err := normalizePortSpec(rule.Port)
if err != nil {
return err
}
if protocol == "" {
return fmt.Errorf("protocol is required when specifying a port")
}
ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSegment))
}
ruleArgs = append(ruleArgs, fmt.Sprintf("-j %s", action))
ruleSpec := strings.Join(ruleArgs, " ")
if operation == "add" {
if err := iptables.AddRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil {
return err
}
} else {
if err := iptables.DeleteRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil {
return err
}
}
name := iptables.BasicFileName
if rule.Chain == iptables.Chain1PanelBasicBefore {
name = iptables.BasicBeforeFileName
}
if rule.Chain == iptables.Chain1PanelBasic {
if err := iptables.SaveRulesToFile(iptables.FilterTab, rule.Chain, name); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
}
}
return nil
}
func (i *Iptables) PortForward(info Forward, operation string) error {
return iptablesPortForward(info, operation)
}
func (i *Iptables) EnableForward() error {
return EnableIptablesForward()
}
func (i *Iptables) ListForward() ([]FireInfo, error) {
return iptablesListForward()
}
func EnableIptablesForward() error {
if err := cmd.RunDefaultBashC("echo 1 > /proc/sys/net/ipv4/ip_forward"); err != nil {
return fmt.Errorf("failed to enable IP forwarding: %w", err)
}
_ = cmd.RunDefaultBashC("grep -q '^net.ipv4.ip_forward' /etc/sysctl.conf || echo 'net.ipv4.ip_forward = 1' >> /etc/sysctl.conf")
_ = cmd.RunDefaultBashC("sysctl -p")
if err := iptables.AddChainWithAppend(iptables.NatTab, "PREROUTING", iptables.Chain1PanelPreRouting); err != nil {
return err
}
if err := iptables.AddChainWithAppend(iptables.NatTab, "POSTROUTING", iptables.Chain1PanelPostRouting); err != nil {
return err
}
if err := iptables.AddChainWithAppend(iptables.FilterTab, "FORWARD", iptables.Chain1PanelForward); err != nil {
return err
}
return nil
}
func iptablesPortForward(info Forward, operation string) error {
if operation != "add" && operation != "remove" {
return buserr.New("ErrCmdIllegal")
}
if info.Protocol == "" || info.Port == "" || info.TargetPort == "" {
return fmt.Errorf("protocol, port, and target port are required")
}
if operation == "add" {
if err := iptables.AddForward(info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface, true); err != nil {
return err
}
} else {
if err := iptables.DeleteForward(info.Num, info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface); err != nil {
return err
}
}
forwardPersistence()
return nil
}
func forwardPersistence() {
if err := iptables.SaveRulesToFile(iptables.FilterTab, iptables.Chain1PanelForward, iptables.ForwardFileName); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelForward, err)
}
if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPreRouting, iptables.ForwardFileName1); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPreRouting, err)
}
if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPostRouting, iptables.ForwardFileName2); err != nil {
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPostRouting, err)
}
}
func iptablesListForward() ([]FireInfo, error) {
natList, err := iptables.ListForward(iptables.Chain1PanelPreRouting)
if err != nil {
return nil, fmt.Errorf("failed to list NAT rules: %w", err)
}
var datas []FireInfo
for _, nat := range natList {
datas = append(datas, FireInfo{
Num: nat.Num,
Protocol: nat.Protocol,
Port: strings.TrimPrefix(nat.SrcPort, ":"),
TargetIP: nat.Destination,
TargetPort: strings.TrimPrefix(nat.DestPort, ":"),
Interface: nat.InIface,
})
}
return datas, nil
}
func parsePort(portStr string) (int, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port < 1 || port > 65535 {
return 0, fmt.Errorf("port out of range: %d", port)
}
return port, nil
}
func normalizePortSpec(port string) (string, error) {
value := strings.TrimSpace(port)
if value == "" {
return "", fmt.Errorf("port is required")
}
separator := ""
if strings.Contains(value, "-") {
separator = "-"
} else if strings.Contains(value, ":") {
separator = ":"
}
if separator != "" {
parts := strings.Split(value, separator)
if len(parts) != 2 {
return "", fmt.Errorf("invalid port range: %s", port)
}
start, err := parsePort(strings.TrimSpace(parts[0]))
if err != nil {
return "", err
}
end, err := parsePort(strings.TrimSpace(parts[1]))
if err != nil {
return "", err
}
if start > end {
return "", fmt.Errorf("invalid port range: %d-%d", start, end)
}
return fmt.Sprintf("%d:%d", start, end), nil
}
single, err := parsePort(value)
if err != nil {
return "", err
}
return fmt.Sprintf("%d", single), nil
}