package service import ( "errors" "fmt" "net" "strings" "github.com/1Panel-dev/1Panel/agent/app/dto" "github.com/1Panel-dev/1Panel/agent/app/model" "github.com/1Panel-dev/1Panel/agent/global" "github.com/1Panel-dev/1Panel/agent/utils/cmd" "github.com/1Panel-dev/1Panel/agent/utils/firewall/client" "github.com/1Panel-dev/1Panel/agent/utils/firewall/client/iptables" ) type IIptablesService interface { Search(req dto.SearchPageWithType) (int64, interface{}, error) OperateRule(req dto.IptablesRuleOp) error BatchOperate(req dto.IptablesBatchOperate) error LoadChainStatus(req dto.OperationWithName) dto.IptablesChainStatus Operate(req dto.IptablesOp) error } type IptablesService struct{} func NewIIptablesService() IIptablesService { return &IptablesService{} } func (s *IptablesService) Search(req dto.SearchPageWithType) (int64, interface{}, error) { rules, err := iptables.ReadFilterRulesByChain(req.Type) if err != nil { return 0, nil, fmt.Errorf("failed to read iptables rules: %w", err) } var records []iptables.FilterRules total, start, end := len(rules), (req.Page-1)*req.PageSize, req.Page*req.PageSize if start > total { records = make([]iptables.FilterRules, 0) } else { if end >= total { end = total } records = rules[start:end] } rulesInDB, _ := hostRepo.ListFirewallRecord(hostRepo.WithByChain(req.Type)) for i := 0; i < len(records); i++ { for _, item := range rulesInDB { if records[i].Strategy == item.Strategy && records[i].DstIP == item.DstIP && fmt.Sprintf("%v", records[i].DstPort) == item.DstPort && records[i].Protocol == item.Protocol && records[i].SrcIP == item.SrcIP && fmt.Sprintf("%v", records[i].SrcPort) == item.SrcPort { records[i].ID = item.ID records[i].Description = item.Description } } } return int64(total), records, nil } func (s *IptablesService) OperateRule(req dto.IptablesRuleOp) error { action := "ACCEPT" if req.Strategy == "drop" { action = "DROP" } policy := iptables.FilterRules{ Protocol: req.Protocol, SrcIP: req.SrcIP, DstIP: req.DstIP, Strategy: action, } if req.SrcPort != 0 { policy.SrcPort = fmt.Sprintf("%v", req.SrcPort) } if req.DstPort != 0 { policy.DstPort = fmt.Sprintf("%v", req.DstPort) } name := iptables.InputFileName if req.Chain == iptables.Chain1PanelOutput { name = iptables.OutputFileName } switch req.Operation { case "add": if err := s.validateRuleInput(&req); err != nil { return err } if err := iptables.AddFilterRule(req.Chain, policy); err != nil { return fmt.Errorf("failed to add iptables rule: %w", err) } if len(req.Description) != 0 { rule := &model.Firewall{ Chain: req.Chain, Protocol: req.Protocol, SrcIP: req.SrcIP, SrcPort: policy.SrcPort, DstIP: req.DstIP, DstPort: policy.DstPort, Strategy: req.Strategy, Description: req.Description, } if err := hostRepo.SaveFirewallRecord(rule); err != nil { return fmt.Errorf("failed to save rule to database: %w", err) } } case "remove": if err := iptables.DeleteFilterRule(req.Chain, policy); err != nil { return fmt.Errorf("failed to remove iptables rule: %w", err) } if req.ID != 0 { if err := hostRepo.DeleteFirewallRecordByID(req.ID); err != nil { return fmt.Errorf("failed to delete rule from database: %w", err) } } } if err := iptables.SaveRulesToFile(iptables.FilterTab, req.Chain, name); err != nil { global.LOG.Errorf("persistence for %s failed, err: %v", iptables.Chain1PanelBasic, err) } return nil } func (s *IptablesService) BatchOperate(req dto.IptablesBatchOperate) error { for _, rule := range req.Rules { if err := s.OperateRule(rule); err != nil { return err } } return nil } func (s *IptablesService) Operate(req dto.IptablesOp) error { targetChain := iptables.ChainInput if req.Name == iptables.Chain1PanelOutput { targetChain = iptables.ChainOutput } switch req.Operate { case "init-base": if ok := cmd.Which("iptables"); !ok { return fmt.Errorf("failed to find iptables") } if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelBasicBefore); err != nil { return err } if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelBasic); err != nil { return err } if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelBasicAfter); err != nil { return err } if err := initPreRules(); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicBefore, 1); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasic, 2); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter, 3); err != nil { return err } return nil case "init-forward": return client.EnableIptablesForward() case "init-advance": if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelInput); err != nil { return err } if err := iptables.AddChain(iptables.FilterTab, iptables.Chain1PanelOutput); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainOutput, iptables.Chain1PanelOutput, 1); err != nil { return err } number := loadBindNumber(iptables.Chain1PanelInput) if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelInput, number); err != nil { return err } return nil case "bind-base": if err := initPreRules(); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicBefore, 1); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasic, 2); err != nil { return err } if err := iptables.BindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter, 3); err != nil { return err } return nil case "unbind-base": if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicAfter); err != nil { return err } if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasicBefore); err != nil { return err } if err := iptables.UnbindChain(iptables.FilterTab, iptables.ChainInput, iptables.Chain1PanelBasic); err != nil { return err } return nil case "bind": if err := iptables.BindChain(iptables.FilterTab, targetChain, req.Name, loadBindNumber(req.Name)); err != nil { return err } return nil case "unbind": if err := iptables.UnbindChain(iptables.FilterTab, targetChain, req.Name); err != nil { return err } return nil } return nil } func (s *IptablesService) LoadChainStatus(req dto.OperationWithName) dto.IptablesChainStatus { var data dto.IptablesChainStatus var err error data.DefaultStrategy, err = iptables.LoadDefaultStrategy(req.Name) if err != nil { global.LOG.Error(err) } switch req.Name { case iptables.Chain1PanelBasic: data.IsBind, err = iptables.CheckChainBind(iptables.FilterTab, iptables.ChainInput, req.Name) case iptables.Chain1PanelInput: data.IsBind, err = iptables.CheckChainBind(iptables.FilterTab, iptables.ChainInput, req.Name) case iptables.Chain1PanelOutput: data.IsBind, err = iptables.CheckChainBind(iptables.FilterTab, iptables.ChainOutput, req.Name) } return data } func (s *IptablesService) validateRuleInput(req *dto.IptablesRuleOp) error { if req.Protocol != "" { validProtocols := map[string]bool{"tcp": true, "udp": true, "icmp": true, "all": true} if !validProtocols[strings.ToLower(req.Protocol)] { return fmt.Errorf("invalid protocol: %s, must be tcp, udp, icmp or all", req.Protocol) } } if req.SrcIP != "" { if err := s.validateIPOrCIDR(req.SrcIP); err != nil { return fmt.Errorf("invalid source IP: %w", err) } } if req.DstIP != "" { if err := s.validateIPOrCIDR(req.DstIP); err != nil { return fmt.Errorf("invalid destination IP: %w", err) } } if req.SrcPort > 65535 { return fmt.Errorf("invalid source port: %d, must be between 1 and 65535", req.SrcPort) } if req.DstPort > 65535 { return fmt.Errorf("invalid destination port: %d, must be between 1 and 65535", req.DstPort) } if (req.SrcPort > 0 || req.DstPort > 0) && req.Protocol == "" { return fmt.Errorf("port specification requires protocol (tcp/udp)") } return nil } func (s *IptablesService) validateIPOrCIDR(ipStr string) error { if strings.Contains(ipStr, "/") { _, _, err := net.ParseCIDR(ipStr) if err != nil { return fmt.Errorf("invalid CIDR format: %w", err) } return nil } ip := net.ParseIP(ipStr) if ip == nil { return fmt.Errorf("invalid IP address format") } return nil } func loadBindNumber(chain string) int { if chain == iptables.Chain1PanelOutput { return 1 } number := 1 if exist, _ := iptables.CheckChainExist(iptables.FilterTab, iptables.Chain1PanelBasicBefore); exist { number++ } if exist, _ := iptables.CheckChainExist(iptables.FilterTab, iptables.Chain1PanelBasic); exist { number++ } return number } func initPreRules() error { if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicBefore, iptables.IoRuleIn); err != nil { return err } if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicBefore, iptables.EstablishedRule); err != nil { return err } panelPort := "" if !global.IsMaster { panelPort = global.CONF.Base.Port } else { var portSetting model.Setting _ = global.CoreDB.Where("key = ?", "ServerPort").First(&portSetting).Error if len(portSetting.Value) != 0 { panelPort = portSetting.Value } } if len(panelPort) == 0 { return errors.New("find 1panel service port failed") } ports := []string{"80", "443", panelPort, loadSSHPort()} for _, item := range ports { if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicBefore, fmt.Sprintf("-p tcp -m tcp --dport %v -j ACCEPT", item)); err != nil { return err } } if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicAfter, iptables.DropAllTcp); err != nil { return err } if err := iptables.AddRule(iptables.FilterTab, iptables.Chain1PanelBasicAfter, iptables.DropAllUdp); err != nil { return err } return nil }