AZURE_DNS: Dedupe nameserver (#3526)

Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
Vatsalya Goel 2025-05-03 22:36:31 +10:00 committed by GitHub
parent c204ccea09
commit 697433563f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 179 additions and 7 deletions

View file

@ -51,8 +51,8 @@ func TestDualProviders(t *testing.T) {
run()
// add bogus nameservers
dc.Records = []*models.RecordConfig{}
nslist, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"})
dc.Nameservers = append(dc.Nameservers, nslist...)
nsList, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"})
dc.Nameservers = append(dc.Nameservers, nsList...)
nameservers.AddNSRecords(dc)
t.Log("Adding test nameservers")
run()
@ -118,3 +118,95 @@ func TestNameserverDots(t *testing.T) {
}
})
}
// TestDuplicateNameservers verifies that a provider de-dupes nameservers if existing nameservers are added.
func TestDuplicateNameservers(t *testing.T) {
// Issue: https://github.com/StackExchange/dnscontrol/issues/3088
// Only configuring for Azure DNS
// Setup:
p, domain, cfg := getProvider(t)
if p == nil {
return
}
if domain == "" {
t.Fatal("NO DOMAIN SET! Exiting!")
}
dc := getDomainConfigWithNameservers(t, p, domain)
if !providers.ProviderHasCapability(*providerFlag, providers.DocDualHost) {
t.Skip("Skipping. DocDualHost == Cannot")
return
}
if cfg["TYPE"] != "AZURE_DNS" {
t.Skip("Skipping. Deduplication logic is not implemented for this provider.")
return
}
// clear everything
run := func(expectedChangeCount int, msg string, ignoreExpected bool) {
dom, _ := dc.Copy()
rs, cs, actualChangeCount, err := zonerecs.CorrectZoneRecords(p, dom)
if err != nil {
t.Fatal(err)
}
for i, c := range rs {
t.Logf("INFO#%d:\n%s", i+1, c.Msg)
}
for i, c := range cs {
t.Logf("#%d:\n%s", i+1, c.Msg)
if err = c.F(); err != nil {
t.Fatal(err)
}
}
if (!ignoreExpected) && actualChangeCount != expectedChangeCount {
t.Logf(msg, actualChangeCount)
t.FailNow()
}
}
t.Log("Clearing everything")
dc.Records = []*models.RecordConfig{}
n := 0
for _, ns := range dc.Nameservers {
if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" {
continue
}
dc.Nameservers[n] = ns
n++
}
dc.Nameservers = dc.Nameservers[:n]
nameservers.AddNSRecords(dc)
run(0, "Clearing everything", true)
// add bogus nameservers and duplicate nameservers
dc.Records = []*models.RecordConfig{}
nsList, _ := models.ToNameservers([]string{"ns1.example.com"})
dc.Nameservers = append(dc.Nameservers, dc.Nameservers...)
dc.Nameservers = append(dc.Nameservers, nsList...)
nameservers.AddNSRecords(dc)
t.Log("Adding test nameservers")
run(1, "Expect 1 correction, but found %d.", false)
// run again to make sure no corrections
t.Log("Running again to ensure stability")
run(0, "Expect no corrections on second run, but found %d.", false)
t.Log("Removing test nameservers")
dc.Records = []*models.RecordConfig{}
n = 0
for _, ns := range dc.Nameservers {
if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" {
continue
}
dc.Nameservers[n] = ns
n++
}
dc.Nameservers = dc.Nameservers[:n]
nameservers.AddNSRecords(dc)
run(0, "Removing test nameservers", true)
}

View file

@ -611,7 +611,7 @@ func Downcase(recs []*RecordConfig) {
r.Name = strings.ToLower(r.Name)
r.NameFQDN = strings.ToLower(r.NameFQDN)
switch r.Type { // #rtype_variations
case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA":
case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA", "AZURE_ALIAS":
// Target is case insensitive. Downcase it.
r.target = strings.ToLower(r.target)
// BUGFIX(tlim): isn't ALIAS in the wrong case statement?

View file

@ -3,8 +3,10 @@ package azuredns
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
@ -173,6 +175,13 @@ func (a *azurednsProvider) GetNameservers(domain string) ([]*models.Nameserver,
for _, ns := range zone.Properties.NameServers {
nss = append(nss, *ns)
}
nonDefaultNss, err := a.getNameNonDefaultNameServers(domain, nss)
if err != nil {
return nil, err
}
nss = append(nss, nonDefaultNss...)
}
return models.ToNameserversStripTD(nss)
@ -193,6 +202,52 @@ func (a *azurednsProvider) ListZones() ([]string, error) {
return zones, nil
}
func (a *azurednsProvider) getNameNonDefaultNameServers(domain string, nss []string) ([]string, error) {
zone, ok := a.zones[domain]
if !ok {
return nil, errNoExist{domain}
}
zoneName := *zone.Name
var nameServers []string
ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
defer cancel()
recordsPager := a.recordsClient.NewListByTypePager(*a.resourceGroup, zoneName, "NS", nil)
for recordsPager.More() {
waitTime := 1
retry:
nextResult, recordsErr := recordsPager.NextPage(ctx)
if recordsErr != nil {
err := recordsErr
var e *azcore.ResponseError
if errors.As(err, &e) {
if e.StatusCode == http.StatusTooManyRequests {
waitTime = waitTime * 2
if waitTime > 300 {
return nil, err
}
printer.Printf("AZURE_DNS: rate-limit paused for %v.\n", waitTime)
time.Sleep(time.Duration(waitTime+1) * time.Second)
goto retry
}
}
}
for _, record := range nextResult.Value {
if record.Properties != nil && domain == removeTrailingDot(*record.Properties.Fqdn) {
for _, ns := range record.Properties.NsRecords {
if !slices.Contains(nss, *ns.Nsdname) {
nameServers = append(nameServers, *ns.Nsdname)
}
}
}
}
}
return nameServers, nil
}
// GetZoneRecords gets the records of a zone and returns them in RecordConfig format.
func (a *azurednsProvider) GetZoneRecords(domain string, meta map[string]string) (models.Records, error) {
existingRecords, _, _, err := a.getExistingRecords(domain)
@ -239,6 +294,12 @@ func (a *azurednsProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, ex
dcn := dc.Name
chaKey := change.Key
if change.Type == diff2.CHANGE || change.Type == diff2.CREATE {
if chaKey.Type == "NS" && dcn == removeTrailingDot(change.Key.NameFQDN) {
change.New = deduplicateNameServerTargets(change.New)
}
}
switch change.Type {
case diff2.REPORT:
corrections = append(corrections, &models.Correction{Msg: change.MsgsJoined})
@ -280,13 +341,14 @@ func (a *azurednsProvider) recordCreate(zoneName string, reckey models.RecordKey
rrset.Properties.TTL = &i
waitTime := 1
retry:
retry:
ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
defer cancel()
_, err = a.recordsClient.CreateOrUpdate(ctx, *a.resourceGroup, zoneName, recordName, azRecType, *rrset, nil)
if e, ok := err.(*azcore.ResponseError); ok {
var e *azcore.ResponseError
if errors.As(err, &e) {
if e.StatusCode == http.StatusTooManyRequests {
waitTime = waitTime * 2
if waitTime > 300 {
@ -319,7 +381,8 @@ retry:
defer cancel()
_, err = a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, shortName, azRecType, nil)
if e, ok := err.(*azcore.ResponseError); ok {
var e *azcore.ResponseError
if errors.As(err, &e) {
if e.StatusCode == http.StatusTooManyRequests {
waitTime = waitTime * 2
if waitTime > 300 {
@ -608,7 +671,8 @@ func (a *azurednsProvider) fetchRecordSets(zoneName string) ([]*adns.RecordSet,
if recordsErr != nil {
err := recordsErr
if e, ok := err.(*azcore.ResponseError); ok {
var e *azcore.ResponseError
if errors.As(err, &e) {
if e.StatusCode == http.StatusTooManyRequests {
waitTime = waitTime * 2
if waitTime > 300 {
@ -643,3 +707,19 @@ func (a *azurednsProvider) EnsureZoneExists(domain string) error {
a.zones[domain] = &z.Zone
return nil
}
func removeTrailingDot(record string) string {
return strings.TrimSuffix(record, ".")
}
func deduplicateNameServerTargets(newRecs models.Records) models.Records {
dedupedMap := make(map[string]bool)
var deduped models.Records
for _, rec := range newRecs {
if !dedupedMap[rec.GetTargetField()] {
dedupedMap[rec.GetTargetField()] = true
deduped = append(deduped, rec)
}
}
return deduped
}