diff --git a/integrationTest/provider_test.go b/integrationTest/provider_test.go index 1b332c967..933ae9187 100644 --- a/integrationTest/provider_test.go +++ b/integrationTest/provider_test.go @@ -51,8 +51,8 @@ func TestDualProviders(t *testing.T) { run() // add bogus nameservers dc.Records = []*models.RecordConfig{} - nslist, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"}) - dc.Nameservers = append(dc.Nameservers, nslist...) + nsList, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"}) + dc.Nameservers = append(dc.Nameservers, nsList...) nameservers.AddNSRecords(dc) t.Log("Adding test nameservers") run() @@ -118,3 +118,95 @@ func TestNameserverDots(t *testing.T) { } }) } + +// TestDuplicateNameservers verifies that a provider de-dupes nameservers if existing nameservers are added. +func TestDuplicateNameservers(t *testing.T) { + // Issue: https://github.com/StackExchange/dnscontrol/issues/3088 + // Only configuring for Azure DNS + + // Setup: + p, domain, cfg := getProvider(t) + if p == nil { + return + } + + if domain == "" { + t.Fatal("NO DOMAIN SET! Exiting!") + } + + dc := getDomainConfigWithNameservers(t, p, domain) + + if !providers.ProviderHasCapability(*providerFlag, providers.DocDualHost) { + t.Skip("Skipping. DocDualHost == Cannot") + return + } + + if cfg["TYPE"] != "AZURE_DNS" { + t.Skip("Skipping. Deduplication logic is not implemented for this provider.") + return + } + + // clear everything + run := func(expectedChangeCount int, msg string, ignoreExpected bool) { + dom, _ := dc.Copy() + + rs, cs, actualChangeCount, err := zonerecs.CorrectZoneRecords(p, dom) + if err != nil { + t.Fatal(err) + } + for i, c := range rs { + t.Logf("INFO#%d:\n%s", i+1, c.Msg) + } + for i, c := range cs { + t.Logf("#%d:\n%s", i+1, c.Msg) + if err = c.F(); err != nil { + t.Fatal(err) + } + } + if (!ignoreExpected) && actualChangeCount != expectedChangeCount { + t.Logf(msg, actualChangeCount) + t.FailNow() + } + } + + t.Log("Clearing everything") + dc.Records = []*models.RecordConfig{} + n := 0 + for _, ns := range dc.Nameservers { + if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" { + continue + } + dc.Nameservers[n] = ns + n++ + } + dc.Nameservers = dc.Nameservers[:n] + nameservers.AddNSRecords(dc) + run(0, "Clearing everything", true) + + // add bogus nameservers and duplicate nameservers + dc.Records = []*models.RecordConfig{} + nsList, _ := models.ToNameservers([]string{"ns1.example.com"}) + dc.Nameservers = append(dc.Nameservers, dc.Nameservers...) + dc.Nameservers = append(dc.Nameservers, nsList...) + nameservers.AddNSRecords(dc) + t.Log("Adding test nameservers") + run(1, "Expect 1 correction, but found %d.", false) + + // run again to make sure no corrections + t.Log("Running again to ensure stability") + run(0, "Expect no corrections on second run, but found %d.", false) + + t.Log("Removing test nameservers") + dc.Records = []*models.RecordConfig{} + n = 0 + for _, ns := range dc.Nameservers { + if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" { + continue + } + dc.Nameservers[n] = ns + n++ + } + dc.Nameservers = dc.Nameservers[:n] + nameservers.AddNSRecords(dc) + run(0, "Removing test nameservers", true) +} diff --git a/models/record.go b/models/record.go index 224ec03b5..b19cc184d 100644 --- a/models/record.go +++ b/models/record.go @@ -611,7 +611,7 @@ func Downcase(recs []*RecordConfig) { r.Name = strings.ToLower(r.Name) r.NameFQDN = strings.ToLower(r.NameFQDN) switch r.Type { // #rtype_variations - case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA": + case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA", "AZURE_ALIAS": // Target is case insensitive. Downcase it. r.target = strings.ToLower(r.target) // BUGFIX(tlim): isn't ALIAS in the wrong case statement? diff --git a/providers/azuredns/azureDnsProvider.go b/providers/azuredns/azureDnsProvider.go index 37032be4d..cdf9e54a8 100644 --- a/providers/azuredns/azureDnsProvider.go +++ b/providers/azuredns/azureDnsProvider.go @@ -3,8 +3,10 @@ package azuredns import ( "context" "encoding/json" + "errors" "fmt" "net/http" + "slices" "strings" "time" @@ -173,6 +175,13 @@ func (a *azurednsProvider) GetNameservers(domain string) ([]*models.Nameserver, for _, ns := range zone.Properties.NameServers { nss = append(nss, *ns) } + + nonDefaultNss, err := a.getNameNonDefaultNameServers(domain, nss) + if err != nil { + return nil, err + } + + nss = append(nss, nonDefaultNss...) } return models.ToNameserversStripTD(nss) @@ -193,6 +202,52 @@ func (a *azurednsProvider) ListZones() ([]string, error) { return zones, nil } +func (a *azurednsProvider) getNameNonDefaultNameServers(domain string, nss []string) ([]string, error) { + zone, ok := a.zones[domain] + if !ok { + return nil, errNoExist{domain} + } + zoneName := *zone.Name + var nameServers []string + ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) + defer cancel() + recordsPager := a.recordsClient.NewListByTypePager(*a.resourceGroup, zoneName, "NS", nil) + + for recordsPager.More() { + waitTime := 1 + retry: + nextResult, recordsErr := recordsPager.NextPage(ctx) + + if recordsErr != nil { + err := recordsErr + var e *azcore.ResponseError + if errors.As(err, &e) { + if e.StatusCode == http.StatusTooManyRequests { + waitTime = waitTime * 2 + if waitTime > 300 { + return nil, err + } + printer.Printf("AZURE_DNS: rate-limit paused for %v.\n", waitTime) + time.Sleep(time.Duration(waitTime+1) * time.Second) + goto retry + } + } + } + + for _, record := range nextResult.Value { + if record.Properties != nil && domain == removeTrailingDot(*record.Properties.Fqdn) { + for _, ns := range record.Properties.NsRecords { + if !slices.Contains(nss, *ns.Nsdname) { + nameServers = append(nameServers, *ns.Nsdname) + } + } + } + } + } + + return nameServers, nil +} + // GetZoneRecords gets the records of a zone and returns them in RecordConfig format. func (a *azurednsProvider) GetZoneRecords(domain string, meta map[string]string) (models.Records, error) { existingRecords, _, _, err := a.getExistingRecords(domain) @@ -239,6 +294,12 @@ func (a *azurednsProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, ex dcn := dc.Name chaKey := change.Key + if change.Type == diff2.CHANGE || change.Type == diff2.CREATE { + if chaKey.Type == "NS" && dcn == removeTrailingDot(change.Key.NameFQDN) { + change.New = deduplicateNameServerTargets(change.New) + } + } + switch change.Type { case diff2.REPORT: corrections = append(corrections, &models.Correction{Msg: change.MsgsJoined}) @@ -280,13 +341,14 @@ func (a *azurednsProvider) recordCreate(zoneName string, reckey models.RecordKey rrset.Properties.TTL = &i waitTime := 1 -retry: +retry: ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) defer cancel() _, err = a.recordsClient.CreateOrUpdate(ctx, *a.resourceGroup, zoneName, recordName, azRecType, *rrset, nil) - if e, ok := err.(*azcore.ResponseError); ok { + var e *azcore.ResponseError + if errors.As(err, &e) { if e.StatusCode == http.StatusTooManyRequests { waitTime = waitTime * 2 if waitTime > 300 { @@ -319,7 +381,8 @@ retry: defer cancel() _, err = a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, shortName, azRecType, nil) - if e, ok := err.(*azcore.ResponseError); ok { + var e *azcore.ResponseError + if errors.As(err, &e) { if e.StatusCode == http.StatusTooManyRequests { waitTime = waitTime * 2 if waitTime > 300 { @@ -608,7 +671,8 @@ func (a *azurednsProvider) fetchRecordSets(zoneName string) ([]*adns.RecordSet, if recordsErr != nil { err := recordsErr - if e, ok := err.(*azcore.ResponseError); ok { + var e *azcore.ResponseError + if errors.As(err, &e) { if e.StatusCode == http.StatusTooManyRequests { waitTime = waitTime * 2 if waitTime > 300 { @@ -643,3 +707,19 @@ func (a *azurednsProvider) EnsureZoneExists(domain string) error { a.zones[domain] = &z.Zone return nil } + +func removeTrailingDot(record string) string { + return strings.TrimSuffix(record, ".") +} + +func deduplicateNameServerTargets(newRecs models.Records) models.Records { + dedupedMap := make(map[string]bool) + var deduped models.Records + for _, rec := range newRecs { + if !dedupedMap[rec.GetTargetField()] { + dedupedMap[rec.GetTargetField()] = true + deduped = append(deduped, rec) + } + } + return deduped +}