mirror of
https://github.com/1Panel-dev/1Panel.git
synced 2025-12-17 21:08:25 +08:00
359 lines
9.3 KiB
Go
359 lines
9.3 KiB
Go
package client
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/1Panel-dev/1Panel/agent/buserr"
|
|
"github.com/1Panel-dev/1Panel/agent/global"
|
|
"github.com/1Panel-dev/1Panel/agent/utils/cmd"
|
|
"github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables"
|
|
)
|
|
|
|
// Iptables manages firewall rules by invoking the iptables tooling. The type
// holds no state of its own.
type Iptables struct{}

// NewIptables returns a new iptables client. The error result is always nil.
func NewIptables() (*Iptables, error) {
	client := new(Iptables)
	return client, nil
}

// Name reports the identifier of this firewall backend, "iptables".
func (i *Iptables) Name() string { return "iptables" }
|
|
|
|
func (i *Iptables) Status() (bool, error) {
|
|
stdout, err := cmd.RunDefaultWithStdoutBashC("iptables -L -n | head -1")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return strings.Contains(stdout, "Chain"), nil
|
|
}
|
|
|
|
func (i *Iptables) Start() error {
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) Restart() error {
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) Reload() error {
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) Version() (string, error) {
|
|
stdout, err := cmd.RunDefaultWithStdoutBashC("iptables --version")
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get iptables version: %w", err)
|
|
}
|
|
parts := strings.Fields(stdout)
|
|
if len(parts) >= 2 {
|
|
return strings.TrimPrefix(parts[1], "v"), nil
|
|
}
|
|
return strings.TrimSpace(stdout), nil
|
|
}
|
|
|
|
func (i *Iptables) ListPort() ([]FireInfo, error) {
|
|
var datas []FireInfo
|
|
basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
beforeRules, _ := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasicBefore)
|
|
basicRules = append(basicRules, beforeRules...)
|
|
for _, item := range basicRules {
|
|
if len(item.DstPort) == 0 {
|
|
continue
|
|
}
|
|
if item.Strategy == "drop" || item.Strategy == "reject" {
|
|
item.Strategy = "drop"
|
|
}
|
|
|
|
datas = append(datas, FireInfo{
|
|
Chain: item.Chain,
|
|
Address: item.SrcIP,
|
|
Protocol: item.Protocol,
|
|
Port: item.DstPort,
|
|
Strategy: item.Strategy,
|
|
Family: "ipv4",
|
|
})
|
|
}
|
|
|
|
return datas, nil
|
|
}
|
|
|
|
func (i *Iptables) ListAddress() ([]FireInfo, error) {
|
|
var datas []FireInfo
|
|
basicRules, err := iptables.ReadFilterRulesByChain(iptables.Chain1PanelBasic)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, item := range basicRules {
|
|
if len(item.DstPort) != 0 || len(item.SrcPort) != 0 {
|
|
continue
|
|
}
|
|
if item.Strategy == "drop" || item.Strategy == "reject" {
|
|
item.Strategy = "drop"
|
|
}
|
|
datas = append(datas, FireInfo{
|
|
Address: item.SrcIP,
|
|
Strategy: item.Strategy,
|
|
Family: "ipv4",
|
|
})
|
|
}
|
|
return datas, nil
|
|
}
|
|
|
|
func (i *Iptables) Port(port FireInfo, operation string) error {
|
|
if operation != "add" && operation != "remove" {
|
|
return buserr.New("ErrCmdIllegal")
|
|
}
|
|
if len(port.Chain) == 0 {
|
|
port.Chain = iptables.Chain1PanelBasic
|
|
}
|
|
|
|
portSpec, err := normalizePortSpec(port.Port)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
protocol := port.Protocol
|
|
if protocol == "" {
|
|
protocol = "tcp"
|
|
}
|
|
|
|
action := "ACCEPT"
|
|
if port.Strategy == "drop" {
|
|
action = "DROP"
|
|
}
|
|
|
|
ruleArgs := []string{fmt.Sprintf("-p %s", protocol)}
|
|
ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSpec), fmt.Sprintf("-j %s", action))
|
|
ruleSpec := strings.Join(ruleArgs, " ")
|
|
if operation == "add" {
|
|
if err := iptables.AddRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := iptables.DeleteRule(iptables.FilterTab, port.Chain, ruleSpec); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
name := iptables.BasicFileName
|
|
if port.Chain == iptables.Chain1PanelBasicBefore {
|
|
name = iptables.BasicBeforeFileName
|
|
}
|
|
if port.Chain == iptables.Chain1PanelBasic {
|
|
if err := iptables.SaveRulesToFile(iptables.FilterTab, port.Chain, name); err != nil {
|
|
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) RichRules(rule FireInfo, operation string) error {
|
|
if operation != "add" && operation != "remove" {
|
|
return buserr.New("ErrCmdIllegal")
|
|
}
|
|
if len(rule.Chain) == 0 {
|
|
rule.Chain = iptables.Chain1PanelBasic
|
|
}
|
|
|
|
address := strings.TrimSpace(rule.Address)
|
|
if strings.EqualFold(address, "Anywhere") {
|
|
address = ""
|
|
}
|
|
|
|
action := "ACCEPT"
|
|
if rule.Strategy == "drop" {
|
|
action = "DROP"
|
|
}
|
|
|
|
var ruleArgs []string
|
|
if address != "" {
|
|
ruleArgs = append(ruleArgs, fmt.Sprintf("-s %s", address))
|
|
}
|
|
|
|
protocol := strings.TrimSpace(rule.Protocol)
|
|
if rule.Port != "" && protocol == "" {
|
|
protocol = "tcp"
|
|
}
|
|
|
|
if protocol != "" {
|
|
ruleArgs = append(ruleArgs, fmt.Sprintf("-p %s", protocol))
|
|
}
|
|
|
|
if rule.Port != "" {
|
|
portSegment, err := normalizePortSpec(rule.Port)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if protocol == "" {
|
|
return fmt.Errorf("protocol is required when specifying a port")
|
|
}
|
|
ruleArgs = append(ruleArgs, fmt.Sprintf("--dport %s", portSegment))
|
|
}
|
|
|
|
ruleArgs = append(ruleArgs, fmt.Sprintf("-j %s", action))
|
|
ruleSpec := strings.Join(ruleArgs, " ")
|
|
if operation == "add" {
|
|
if err := iptables.AddRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := iptables.DeleteRule(iptables.FilterTab, rule.Chain, ruleSpec); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
name := iptables.BasicFileName
|
|
if rule.Chain == iptables.Chain1PanelBasicBefore {
|
|
name = iptables.BasicBeforeFileName
|
|
}
|
|
if rule.Chain == iptables.Chain1PanelBasic {
|
|
if err := iptables.SaveRulesToFile(iptables.FilterTab, rule.Chain, name); err != nil {
|
|
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *Iptables) PortForward(info Forward, operation string) error {
|
|
return iptablesPortForward(info, operation)
|
|
}
|
|
|
|
func (i *Iptables) EnableForward() error {
|
|
return EnableIptablesForward()
|
|
}
|
|
|
|
func (i *Iptables) ListForward() ([]FireInfo, error) {
|
|
return iptablesListForward()
|
|
}
|
|
|
|
func EnableIptablesForward() error {
|
|
if err := cmd.RunDefaultBashC("echo 1 > /proc/sys/net/ipv4/ip_forward"); err != nil {
|
|
return fmt.Errorf("failed to enable IP forwarding: %w", err)
|
|
}
|
|
_ = cmd.RunDefaultBashC("grep -q '^net.ipv4.ip_forward' /etc/sysctl.conf || echo 'net.ipv4.ip_forward = 1' >> /etc/sysctl.conf")
|
|
_ = cmd.RunDefaultBashC("sysctl -p")
|
|
|
|
if err := iptables.AddChainWithAppend(iptables.NatTab, "PREROUTING", iptables.Chain1PanelPreRouting); err != nil {
|
|
return err
|
|
}
|
|
if err := iptables.AddChainWithAppend(iptables.NatTab, "POSTROUTING", iptables.Chain1PanelPostRouting); err != nil {
|
|
return err
|
|
}
|
|
if err := iptables.AddChainWithAppend(iptables.FilterTab, "FORWARD", iptables.Chain1PanelForward); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func iptablesPortForward(info Forward, operation string) error {
|
|
if operation != "add" && operation != "remove" {
|
|
return buserr.New("ErrCmdIllegal")
|
|
}
|
|
if info.Protocol == "" || info.Port == "" || info.TargetPort == "" {
|
|
return fmt.Errorf("protocol, port, and target port are required")
|
|
}
|
|
if operation == "add" {
|
|
if err := iptables.AddForward(info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface, true); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := iptables.DeleteForward(info.Num, info.Protocol, info.Port, info.TargetIP, info.TargetPort, info.Interface); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
forwardPersistence()
|
|
return nil
|
|
}
|
|
|
|
func forwardPersistence() {
|
|
if err := iptables.SaveRulesToFile(iptables.FilterTab, iptables.Chain1PanelForward, iptables.ForwardFileName); err != nil {
|
|
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelForward, err)
|
|
}
|
|
if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPreRouting, iptables.ForwardFileName1); err != nil {
|
|
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPreRouting, err)
|
|
}
|
|
if err := iptables.SaveRulesToFile(iptables.NatTab, iptables.Chain1PanelPostRouting, iptables.ForwardFileName2); err != nil {
|
|
global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelPostRouting, err)
|
|
}
|
|
}
|
|
|
|
func iptablesListForward() ([]FireInfo, error) {
|
|
natList, err := iptables.ListForward(iptables.Chain1PanelPreRouting)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list NAT rules: %w", err)
|
|
}
|
|
|
|
var datas []FireInfo
|
|
for _, nat := range natList {
|
|
datas = append(datas, FireInfo{
|
|
Num: nat.Num,
|
|
Protocol: nat.Protocol,
|
|
Port: strings.TrimPrefix(nat.SrcPort, ":"),
|
|
TargetIP: nat.Destination,
|
|
TargetPort: strings.TrimPrefix(nat.DestPort, ":"),
|
|
Interface: nat.InIface,
|
|
})
|
|
}
|
|
|
|
return datas, nil
|
|
}
|
|
|
|
// parsePort converts portStr to an int and checks that it lies in 1-65535.
// The string is parsed as-is (no trimming).
func parsePort(portStr string) (int, error) {
	n, convErr := strconv.Atoi(portStr)
	switch {
	case convErr != nil:
		return 0, fmt.Errorf("invalid port number: %s", portStr)
	case n < 1 || n > 65535:
		return 0, fmt.Errorf("port out of range: %d", n)
	default:
		return n, nil
	}
}

// normalizePortSpec validates a single port ("80") or a port range
// ("8000-9000" or "8000:9000") and returns it in the form used for --dport:
// "80" or "8000:9000".
//
// Whitespace around the whole value and around each range end is ignored.
// If the value contains '-', that is the range separator; otherwise ':' is.
// A range with more than one separator, or whose start exceeds its end, is
// rejected.
func normalizePortSpec(port string) (string, error) {
	value := strings.TrimSpace(port)
	if value == "" {
		return "", fmt.Errorf("port is required")
	}

	sep := ":"
	if strings.Contains(value, "-") {
		sep = "-"
	}
	low, high, isRange := strings.Cut(value, sep)
	if !isRange {
		single, err := parsePort(value)
		if err != nil {
			return "", err
		}
		return strconv.Itoa(single), nil
	}
	if strings.Contains(high, sep) {
		return "", fmt.Errorf("invalid port range: %s", port)
	}

	start, err := parsePort(strings.TrimSpace(low))
	if err != nil {
		return "", err
	}
	end, err := parsePort(strings.TrimSpace(high))
	if err != nil {
		return "", err
	}
	if start > end {
		return "", fmt.Errorf("invalid port range: %d-%d", start, end)
	}
	return strconv.Itoa(start) + ":" + strconv.Itoa(end), nil
}
|