1Panel/agent/app/service/iptables.go

323 lines
9.8 KiB
Go

package service
import (
"errors"
"fmt"
"net"
"strings"
"github.com/1Panel-dev/1Panel/agent/app/dto"
"github.com/1Panel-dev/1Panel/agent/app/model"
"github.com/1Panel-dev/1Panel/agent/global"
"github.com/1Panel-dev/1Panel/agent/utils/cmd"
"github.com/1Panel-dev/1Panel/agent/utils/firewall/client"
"github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables"
)
// IIptablesService is the service-layer contract for managing the iptables
// filter rules and chains handled by this package.
type IIptablesService interface {
// Search returns the total number of filter rules in the chain named by
// req.Type together with the requested page of those rules.
Search(req dto.SearchPageWithType) (int64, interface{}, error)
// OperateRule applies a single rule operation ("add" or "remove") to req.Chain.
OperateRule(req dto.IptablesRuleOp) error
// BatchOperate applies every rule operation in req.Rules in order and stops
// at the first failure.
BatchOperate(req dto.IptablesBatchOperate) error
// LoadChainStatus reports the default strategy and bind state of the chain
// named by req.Name.
LoadChainStatus(req dto.OperationWithName) dto.IptablesChainStatus
// Operate runs a chain-level operation selected by req.Operate
// (init-base, init-forward, init-advance, bind-base, unbind-base, bind, unbind).
Operate(req dto.IptablesOp) error
}
// IptablesService is the stateless implementation of IIptablesService.
type IptablesService struct{}
// NewIIptablesService returns a ready-to-use IIptablesService backed by a
// fresh, zero-valued IptablesService.
func NewIIptablesService() IIptablesService {
	return new(IptablesService)
}
// Search reads the filter rules of the chain named by req.Type and returns the
// total rule count together with the page selected by req.Page / req.PageSize.
//
// Each returned rule is enriched with the ID and Description of the stored
// firewall record whose strategy, protocol, addresses and ports match it.
// When several stored records match, the last one wins.
//
// A page that lies outside the rule list (including a non-positive page number
// or page size) yields an empty, non-nil slice instead of panicking on an
// out-of-range slice expression.
func (s *IptablesService) Search(req dto.SearchPageWithType) (int64, interface{}, error) {
	rules, err := iptables.ReadFilterRulesByChain(req.Type)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to read iptables rules: %w", err)
	}

	total := len(rules)
	start, end := (req.Page-1)*req.PageSize, req.Page*req.PageSize

	// Start from an empty, non-nil slice so every "no results" case is
	// returned the same way.
	records := make([]iptables.FilterRules, 0)
	if start >= 0 && start < total && end > start {
		if end > total {
			end = total
		}
		records = rules[start:end]
	}

	// The stored records only decorate the live rules with an ID and a
	// description, so a lookup failure is tolerated and the rules are returned
	// undecorated.
	// NOTE(review): confirm that dropping this error is intended.
	rulesInDB, _ := hostRepo.ListFirewallRecord(hostRepo.WithByChain(req.Type))
	for i := range records {
		// Ports are stored as strings in the database; format once per rule.
		dstPort := fmt.Sprintf("%v", records[i].DstPort)
		srcPort := fmt.Sprintf("%v", records[i].SrcPort)
		for _, item := range rulesInDB {
			if records[i].Strategy == item.Strategy &&
				records[i].DstIP == item.DstIP &&
				dstPort == item.DstPort &&
				records[i].Protocol == item.Protocol &&
				records[i].SrcIP == item.SrcIP &&
				srcPort == item.SrcPort {
				records[i].ID = item.ID
				records[i].Description = item.Description
			}
		}
	}
	return int64(total), records, nil
}
// OperateRule adds or removes a single filter rule in req.Chain.
//
//   - "add": validates the request, inserts the rule into iptables and stores a
//     matching firewall record (including the description) in the database.
//   - "remove": deletes the rule from iptables and, when req.ID is non-zero,
//     deletes the stored firewall record with that ID.
//
// Any other operation performs no rule change. In every non-error case the
// chain is then persisted to its rules file; a persistence failure is logged
// but not returned.
func (s *IptablesService) OperateRule(req dto.IptablesRuleOp) error {
	policy := iptables.FilterRules{
		Protocol: req.Protocol,
		SrcIP:    req.SrcIP,
		SrcPort:  req.SrcPort,
		DstIP:    req.DstIP,
		DstPort:  req.DstPort,
		Strategy: req.Strategy,
	}

	// The output chain is persisted to its own file; everything else goes to
	// the input file.
	name := iptables.InputFileName
	if req.Chain == iptables.Chain1PanelOutput {
		name = iptables.OutputFileName
	}

	switch req.Operation {
	case "add":
		if err := s.validateRuleInput(&req); err != nil {
			return err
		}
		if err := iptables.AddFilterRule(req.Chain, policy); err != nil {
			return fmt.Errorf("failed to add iptables rule: %w", err)
		}
		rule := &model.Firewall{
			Chain:       req.Chain,
			Protocol:    req.Protocol,
			SrcIP:       req.SrcIP,
			SrcPort:     fmt.Sprintf("%v", req.SrcPort),
			DstIP:       req.DstIP,
			DstPort:     fmt.Sprintf("%v", req.DstPort),
			Strategy:    req.Strategy,
			Description: req.Description,
		}
		if err := hostRepo.SaveFirewallRecord(rule); err != nil {
			return fmt.Errorf("failed to save rule to database: %w", err)
		}
	case "remove":
		if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil {
			return fmt.Errorf("failed to remove iptables rule: %w", err)
		}
		if req.ID != 0 {
			if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil {
				return fmt.Errorf("failed to delete rule from database: %w", err)
			}
		}
	}

	if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil {
		// Report the chain that was actually being persisted.
		global.LOG.Errorf("persistence for %s failed, err: %v", req.Chain, err)
	}
	return nil
}
// BatchOperate runs OperateRule for each entry of req.Rules, in order.
// Processing stops at the first failing entry and its error is returned
// unchanged; entries already applied are not rolled back.
func (s *IptablesService) BatchOperate(req dto.IptablesBatchOperate) error {
	for idx := 0; idx < len(req.Rules); idx++ {
		err := s.OperateRule(req.Rules[idx])
		if err != nil {
			return err
		}
	}
	return nil
}
// Operate runs the chain-level operation selected by req.Operate:
//
//   - "init-base":    requires the iptables binary, creates the three basic
//     chains, installs the pre-rules and binds the chains to INPUT.
//   - "init-forward": enables iptables forwarding via the firewall client.
//   - "init-advance": creates the 1Panel input/output chains and binds them to
//     OUTPUT and INPUT at the position returned by loadBindNumber.
//   - "bind-base":    installs the pre-rules and binds the basic chains to INPUT.
//   - "unbind-base":  unbinds the three basic chains from INPUT.
//   - "bind"/"unbind": binds or unbinds req.Name on OUTPUT when req.Name is the
//     1Panel output chain, otherwise on INPUT.
//
// Any other value of req.Operate is a no-op that returns nil. The first
// failing step aborts the operation and its error is returned.
func (s *IptablesService) Operate(req dto.IptablesOp) error {
	targetChain := iptables.ChainInput
	if req.Name == iptables.Chain1PanelOutput {
		targetChain = iptables.ChainOutput
	}

	switch req.Operate {
	case "init-base":
		if ok := cmd.Which("iptables"); !ok {
			return fmt.Errorf("failed to find iptables")
		}
		for _, chain := range basicChainsInBindOrder() {
			if err := iptables.AddChain(iptables.FilterTab, chain); err != nil {
				return err
			}
		}
		return bindBasicChainsToInput()
	case "init-forward":
		return client.EnableIptablesForward()
	case "init-advance":
		if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelInput); err != nil {
			return err
		}
		if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelOutput); err != nil {
			return err
		}
		number := loadBindNumber()
		if err := iptables.BindChain(iptables.FilterTab, iptables.ChainOutput, iptables.Chain1PanelOutput, number); err != nil {
			return err
		}
		if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelInput, number); err != nil {
			return err
		}
		return nil
	case "bind-base":
		return bindBasicChainsToInput()
	case "unbind-base":
		// Unbind order is deliberately kept as: after, before, basic.
		unbindOrder := []string{
			iptables.Chain1PanelBasicAfter,
			iptables.Chain1PanelBasicBefore,
			iptables.Chain1PanelBasic,
		}
		for _, chain := range unbindOrder {
			if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, chain); err != nil {
				return err
			}
		}
		return nil
	case "bind":
		if err := iptables.BindChain(iptables.FilterTab, targetChain, req.Name, loadBindNumber()); err != nil {
			return err
		}
		return nil
	case "unbind":
		if err := iptables.UnbindChain(iptables.FilterTab, targetChain, req.Name); err != nil {
			return err
		}
		return nil
	}
	return nil
}

// basicChainsInBindOrder lists the three basic chains in the order in which
// they are created and bound to INPUT (positions 1, 2 and 3).
func basicChainsInBindOrder() []string {
	return []string{
		iptables.Chain1PanelBasicBefore,
		iptables.Chain1PanelBasic,
		iptables.Chain1PanelBasicAfter,
	}
}

// bindBasicChainsToInput installs the pre-rules and then binds the basic
// chains to INPUT at positions 1..3, stopping at the first error.
func bindBasicChainsToInput() error {
	if err := initPreRules(); err != nil {
		return err
	}
	for i, chain := range basicChainsInBindOrder() {
		if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, chain, i+1); err != nil {
			return err
		}
	}
	return nil
}
// LoadChainStatus reports the default strategy of the chain named by req.Name
// and, for the basic, input and output 1Panel chains, whether the chain is
// bound to its parent (INPUT for basic/input, OUTPUT for output).
//
// A failure to load the default strategy is logged and the returned value is
// still stored in the result. For any other chain name IsBind keeps its zero
// value.
func (s *IptablesService) LoadChainStatus(req dto.OperationWithName) dto.IptablesChainStatus {
	var status dto.IptablesChainStatus

	strategy, err := iptables.LoadDefaultStrategy(req.Name)
	status.DefaultStrategy = strategy
	if err != nil {
		global.LOG.Error(err)
	}

	parent := iptables.ChainInput
	switch req.Name {
	case iptables.Chain1PanelBasic, iptables.Chain1PanelInput:
		// Both of these hang off INPUT, which is already selected.
	case iptables.Chain1PanelOutput:
		parent = iptables.ChainOutput
	default:
		return status
	}

	// NOTE(review): the bind-check error is discarded here, as it always has
	// been; confirm whether it should be logged.
	status.IsBind, _ = iptables.CheckChainBind(iptables.FilterTab, parent, req.Name)
	return status
}
// validateRuleInput checks a rule request before it is handed to iptables.
//
// Rules enforced:
//   - Protocol, when set, must be tcp, udp, icmp or all (case-insensitive).
//   - SrcIP / DstIP, when set, must be a valid IP address or CIDR block.
//   - SrcPort / DstPort must not exceed 65535; zero means "no port".
//   - A non-zero port is only accepted together with protocol tcp or udp,
//     as the returned error message states.
//
// It returns nil when the request is acceptable and a descriptive error
// otherwise.
func (s *IptablesService) validateRuleInput(req *dto.IptablesRuleOp) error {
	protocol := strings.ToLower(req.Protocol)
	switch protocol {
	case "", "tcp", "udp", "icmp", "all":
	default:
		return fmt.Errorf("invalid protocol: %s, must be tcp, udp, icmp or all", req.Protocol)
	}

	if req.SrcIP != "" {
		if err := s.validateIPOrCIDR(req.SrcIP); err != nil {
			return fmt.Errorf("invalid source IP: %w", err)
		}
	}
	if req.DstIP != "" {
		if err := s.validateIPOrCIDR(req.DstIP); err != nil {
			return fmt.Errorf("invalid destination IP: %w", err)
		}
	}

	if req.SrcPort > 65535 {
		return fmt.Errorf("invalid source port: %d, must be between 1 and 65535", req.SrcPort)
	}
	if req.DstPort > 65535 {
		return fmt.Errorf("invalid destination port: %d, must be between 1 and 65535", req.DstPort)
	}

	// Previously only an empty protocol was rejected here, so icmp/all with a
	// port passed validation despite the tcp/udp requirement in the message.
	hasPort := req.SrcPort > 0 || req.DstPort > 0
	if hasPort && protocol != "tcp" && protocol != "udp" {
		return fmt.Errorf("port specification requires protocol (tcp/udp)")
	}
	return nil
}
// validateIPOrCIDR reports whether ipStr is a well-formed address.
// A value containing "/" is parsed as a CIDR block; anything else is parsed
// as a single IPv4 or IPv6 address. It returns nil when parsing succeeds and
// an error describing the failure otherwise.
func (s *IptablesService) validateIPOrCIDR(ipStr string) error {
	if !strings.Contains(ipStr, "/") {
		// Plain address: net.ParseIP signals failure with a nil result.
		if net.ParseIP(ipStr) == nil {
			return fmt.Errorf("invalid IP address format")
		}
		return nil
	}

	if _, _, parseErr := net.ParseCIDR(ipStr); parseErr != nil {
		return fmt.Errorf("invalid CIDR format: %w", parseErr)
	}
	return nil
}
// loadBindNumber returns the position used when binding a chain: it starts at
// 1 and is incremented once for each of the "basic before" and "basic" chains
// that exists in the filter table. Errors from the existence check are
// ignored and treated like "does not exist".
func loadBindNumber() int {
	position := 1
	for _, chain := range []string{iptables.Chain1PanelBasicBefore, iptables.Chain1PanelBasic} {
		found, _ := iptables.CheckChainExist(iptables.FilterTab, chain)
		if found {
			position++
		}
	}
	return position
}
// initPreRules installs the baseline rules of the basic chains:
//
//  1. the IoRuleIn and EstablishedRule rules in the "basic before" chain,
//  2. TCP ACCEPT rules in that chain for ports 80, 443, the panel port and
//     the port returned by loadSSHPort,
//  3. the DropAll rule in the "basic after" chain.
//
// The panel port comes from the local configuration on a non-master node and
// from the "ServerPort" setting in the core database on the master. If no
// port can be determined an error is returned after step 1 has already run.
func initPreRules() error {
	for _, rule := range []string{iptables.IoRuleIn, iptables.EstablishedRule} {
		if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicBefore, rule); err != nil {
			return err
		}
	}

	var panelPort string
	if global.IsMaster {
		// The lookup error is ignored; an empty value is caught just below.
		var setting model.Setting
		_ = global.CoreDB.Where("key = ?", "ServerPort").First(&setting).Error
		panelPort = setting.Value
	} else {
		panelPort = global.CONF.Base.Port
	}
	if panelPort == "" {
		return errors.New("find 1panel service port failed")
	}

	allowed := []string{"80", "443", panelPort, loadSSHPort()}
	for _, port := range allowed {
		accept := fmt.Sprintf("-p tcp -m tcp --dport %v -j ACCEPT", port)
		if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicBefore, accept); err != nil {
			return err
		}
	}

	return iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicAfter, iptables.DropAll)
}