diff --git a/controllers/node_test.go b/controllers/node_test.go index afc9b65f..02495fd2 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -21,6 +21,10 @@ var linuxHost models.Host func TestCreateEgressGateway(t *testing.T) { var gateway models.EgressGatewayRequest gateway.Ranges = []string{"10.100.100.0/24"} + gateway.RangesWithMetric = append(gateway.RangesWithMetric, models.EgressRangeMetric{ + Network: "10.100.100.0/24", + RouteMetric: 256, + }) gateway.NetID = "skynet" deleteAllNetworks() createNet() diff --git a/logic/gateway.go b/logic/gateway.go index 3367fd91..a72ab673 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -77,6 +77,14 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro if host.FirewallInUse == models.FIREWALL_NONE { return models.Node{}, errors.New("please install iptables or nftables on the device") } + if len(gateway.RangesWithMetric) == 0 && len(gateway.Ranges) > 0 { + for _, rangeI := range gateway.Ranges { + gateway.RangesWithMetric = append(gateway.RangesWithMetric, models.EgressRangeMetric{ + Network: rangeI, + RouteMetric: 256, + }) + } + } for i := len(gateway.Ranges) - 1; i >= 0; i-- { // check if internet gateway IPv4 if gateway.Ranges[i] == "0.0.0.0/0" || gateway.Ranges[i] == "::/0" { @@ -105,9 +113,19 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro node.EgressGatewayRanges = gateway.Ranges node.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled) rangesWithMetric := []string{} - for i, rangeI := range gateway.RangesWithMetric { - rangesWithMetric = append(rangesWithMetric, rangeI.Network) - if rangeI.RouteMetric <= 0 || rangeI.RouteMetric > 999 { + for i := len(gateway.RangesWithMetric) - 1; i >= 0; i-- { + if gateway.RangesWithMetric[i].Network == "0.0.0.0/0" || gateway.RangesWithMetric[i].Network == "::/0" { + // remove inet range + gateway.RangesWithMetric = append(gateway.RangesWithMetric[:i], gateway.RangesWithMetric[i+1:]...) + continue + } + normalized, err := NormalizeCIDR(gateway.Ranges[i]) + if err != nil { + return models.Node{}, err + } + gateway.RangesWithMetric[i].Network = normalized + rangesWithMetric = append(rangesWithMetric, gateway.RangesWithMetric[i].Network) + if gateway.RangesWithMetric[i].RouteMetric <= 0 || gateway.RangesWithMetric[i].RouteMetric > 999 { gateway.RangesWithMetric[i].RouteMetric = 256 } }