mirror of
https://github.com/StackExchange/dnscontrol.git
synced 2025-09-06 05:04:29 +08:00
AZURE_DNS: Dedupe nameserver (#3526)
Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
parent
c204ccea09
commit
697433563f
3 changed files with 179 additions and 7 deletions
|
@ -51,8 +51,8 @@ func TestDualProviders(t *testing.T) {
|
|||
run()
|
||||
// add bogus nameservers
|
||||
dc.Records = []*models.RecordConfig{}
|
||||
nslist, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"})
|
||||
dc.Nameservers = append(dc.Nameservers, nslist...)
|
||||
nsList, _ := models.ToNameservers([]string{"ns1.example.com", "ns2.example.com"})
|
||||
dc.Nameservers = append(dc.Nameservers, nsList...)
|
||||
nameservers.AddNSRecords(dc)
|
||||
t.Log("Adding test nameservers")
|
||||
run()
|
||||
|
@ -118,3 +118,95 @@ func TestNameserverDots(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDuplicateNameservers verifies that a provider de-dupes nameservers if existing nameservers are added.
|
||||
func TestDuplicateNameservers(t *testing.T) {
|
||||
// Issue: https://github.com/StackExchange/dnscontrol/issues/3088
|
||||
// Only configuring for Azure DNS
|
||||
|
||||
// Setup:
|
||||
p, domain, cfg := getProvider(t)
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if domain == "" {
|
||||
t.Fatal("NO DOMAIN SET! Exiting!")
|
||||
}
|
||||
|
||||
dc := getDomainConfigWithNameservers(t, p, domain)
|
||||
|
||||
if !providers.ProviderHasCapability(*providerFlag, providers.DocDualHost) {
|
||||
t.Skip("Skipping. DocDualHost == Cannot")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg["TYPE"] != "AZURE_DNS" {
|
||||
t.Skip("Skipping. Deduplication logic is not implemented for this provider.")
|
||||
return
|
||||
}
|
||||
|
||||
// clear everything
|
||||
run := func(expectedChangeCount int, msg string, ignoreExpected bool) {
|
||||
dom, _ := dc.Copy()
|
||||
|
||||
rs, cs, actualChangeCount, err := zonerecs.CorrectZoneRecords(p, dom)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i, c := range rs {
|
||||
t.Logf("INFO#%d:\n%s", i+1, c.Msg)
|
||||
}
|
||||
for i, c := range cs {
|
||||
t.Logf("#%d:\n%s", i+1, c.Msg)
|
||||
if err = c.F(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if (!ignoreExpected) && actualChangeCount != expectedChangeCount {
|
||||
t.Logf(msg, actualChangeCount)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("Clearing everything")
|
||||
dc.Records = []*models.RecordConfig{}
|
||||
n := 0
|
||||
for _, ns := range dc.Nameservers {
|
||||
if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" {
|
||||
continue
|
||||
}
|
||||
dc.Nameservers[n] = ns
|
||||
n++
|
||||
}
|
||||
dc.Nameservers = dc.Nameservers[:n]
|
||||
nameservers.AddNSRecords(dc)
|
||||
run(0, "Clearing everything", true)
|
||||
|
||||
// add bogus nameservers and duplicate nameservers
|
||||
dc.Records = []*models.RecordConfig{}
|
||||
nsList, _ := models.ToNameservers([]string{"ns1.example.com"})
|
||||
dc.Nameservers = append(dc.Nameservers, dc.Nameservers...)
|
||||
dc.Nameservers = append(dc.Nameservers, nsList...)
|
||||
nameservers.AddNSRecords(dc)
|
||||
t.Log("Adding test nameservers")
|
||||
run(1, "Expect 1 correction, but found %d.", false)
|
||||
|
||||
// run again to make sure no corrections
|
||||
t.Log("Running again to ensure stability")
|
||||
run(0, "Expect no corrections on second run, but found %d.", false)
|
||||
|
||||
t.Log("Removing test nameservers")
|
||||
dc.Records = []*models.RecordConfig{}
|
||||
n = 0
|
||||
for _, ns := range dc.Nameservers {
|
||||
if ns.Name == "ns1.example.com" || ns.Name == "ns2.example.com" {
|
||||
continue
|
||||
}
|
||||
dc.Nameservers[n] = ns
|
||||
n++
|
||||
}
|
||||
dc.Nameservers = dc.Nameservers[:n]
|
||||
nameservers.AddNSRecords(dc)
|
||||
run(0, "Removing test nameservers", true)
|
||||
}
|
||||
|
|
|
@ -611,7 +611,7 @@ func Downcase(recs []*RecordConfig) {
|
|||
r.Name = strings.ToLower(r.Name)
|
||||
r.NameFQDN = strings.ToLower(r.NameFQDN)
|
||||
switch r.Type { // #rtype_variations
|
||||
case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA":
|
||||
case "AKAMAICDN", "ALIAS", "AAAA", "ANAME", "CNAME", "DNAME", "DS", "DNSKEY", "MX", "NS", "NAPTR", "PTR", "SRV", "TLSA", "AZURE_ALIAS":
|
||||
// Target is case insensitive. Downcase it.
|
||||
r.target = strings.ToLower(r.target)
|
||||
// BUGFIX(tlim): isn't ALIAS in the wrong case statement?
|
||||
|
|
|
@ -3,8 +3,10 @@ package azuredns
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -173,6 +175,13 @@ func (a *azurednsProvider) GetNameservers(domain string) ([]*models.Nameserver,
|
|||
for _, ns := range zone.Properties.NameServers {
|
||||
nss = append(nss, *ns)
|
||||
}
|
||||
|
||||
nonDefaultNss, err := a.getNameNonDefaultNameServers(domain, nss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nss = append(nss, nonDefaultNss...)
|
||||
}
|
||||
|
||||
return models.ToNameserversStripTD(nss)
|
||||
|
@ -193,6 +202,52 @@ func (a *azurednsProvider) ListZones() ([]string, error) {
|
|||
return zones, nil
|
||||
}
|
||||
|
||||
func (a *azurednsProvider) getNameNonDefaultNameServers(domain string, nss []string) ([]string, error) {
|
||||
zone, ok := a.zones[domain]
|
||||
if !ok {
|
||||
return nil, errNoExist{domain}
|
||||
}
|
||||
zoneName := *zone.Name
|
||||
var nameServers []string
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
|
||||
defer cancel()
|
||||
recordsPager := a.recordsClient.NewListByTypePager(*a.resourceGroup, zoneName, "NS", nil)
|
||||
|
||||
for recordsPager.More() {
|
||||
waitTime := 1
|
||||
retry:
|
||||
nextResult, recordsErr := recordsPager.NextPage(ctx)
|
||||
|
||||
if recordsErr != nil {
|
||||
err := recordsErr
|
||||
var e *azcore.ResponseError
|
||||
if errors.As(err, &e) {
|
||||
if e.StatusCode == http.StatusTooManyRequests {
|
||||
waitTime = waitTime * 2
|
||||
if waitTime > 300 {
|
||||
return nil, err
|
||||
}
|
||||
printer.Printf("AZURE_DNS: rate-limit paused for %v.\n", waitTime)
|
||||
time.Sleep(time.Duration(waitTime+1) * time.Second)
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, record := range nextResult.Value {
|
||||
if record.Properties != nil && domain == removeTrailingDot(*record.Properties.Fqdn) {
|
||||
for _, ns := range record.Properties.NsRecords {
|
||||
if !slices.Contains(nss, *ns.Nsdname) {
|
||||
nameServers = append(nameServers, *ns.Nsdname)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nameServers, nil
|
||||
}
|
||||
|
||||
// GetZoneRecords gets the records of a zone and returns them in RecordConfig format.
|
||||
func (a *azurednsProvider) GetZoneRecords(domain string, meta map[string]string) (models.Records, error) {
|
||||
existingRecords, _, _, err := a.getExistingRecords(domain)
|
||||
|
@ -239,6 +294,12 @@ func (a *azurednsProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, ex
|
|||
dcn := dc.Name
|
||||
chaKey := change.Key
|
||||
|
||||
if change.Type == diff2.CHANGE || change.Type == diff2.CREATE {
|
||||
if chaKey.Type == "NS" && dcn == removeTrailingDot(change.Key.NameFQDN) {
|
||||
change.New = deduplicateNameServerTargets(change.New)
|
||||
}
|
||||
}
|
||||
|
||||
switch change.Type {
|
||||
case diff2.REPORT:
|
||||
corrections = append(corrections, &models.Correction{Msg: change.MsgsJoined})
|
||||
|
@ -280,13 +341,14 @@ func (a *azurednsProvider) recordCreate(zoneName string, reckey models.RecordKey
|
|||
rrset.Properties.TTL = &i
|
||||
|
||||
waitTime := 1
|
||||
retry:
|
||||
|
||||
retry:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
|
||||
defer cancel()
|
||||
_, err = a.recordsClient.CreateOrUpdate(ctx, *a.resourceGroup, zoneName, recordName, azRecType, *rrset, nil)
|
||||
|
||||
if e, ok := err.(*azcore.ResponseError); ok {
|
||||
var e *azcore.ResponseError
|
||||
if errors.As(err, &e) {
|
||||
if e.StatusCode == http.StatusTooManyRequests {
|
||||
waitTime = waitTime * 2
|
||||
if waitTime > 300 {
|
||||
|
@ -319,7 +381,8 @@ retry:
|
|||
defer cancel()
|
||||
_, err = a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, shortName, azRecType, nil)
|
||||
|
||||
if e, ok := err.(*azcore.ResponseError); ok {
|
||||
var e *azcore.ResponseError
|
||||
if errors.As(err, &e) {
|
||||
if e.StatusCode == http.StatusTooManyRequests {
|
||||
waitTime = waitTime * 2
|
||||
if waitTime > 300 {
|
||||
|
@ -608,7 +671,8 @@ func (a *azurednsProvider) fetchRecordSets(zoneName string) ([]*adns.RecordSet,
|
|||
|
||||
if recordsErr != nil {
|
||||
err := recordsErr
|
||||
if e, ok := err.(*azcore.ResponseError); ok {
|
||||
var e *azcore.ResponseError
|
||||
if errors.As(err, &e) {
|
||||
if e.StatusCode == http.StatusTooManyRequests {
|
||||
waitTime = waitTime * 2
|
||||
if waitTime > 300 {
|
||||
|
@ -643,3 +707,19 @@ func (a *azurednsProvider) EnsureZoneExists(domain string) error {
|
|||
a.zones[domain] = &z.Zone
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeTrailingDot(record string) string {
|
||||
return strings.TrimSuffix(record, ".")
|
||||
}
|
||||
|
||||
func deduplicateNameServerTargets(newRecs models.Records) models.Records {
|
||||
dedupedMap := make(map[string]bool)
|
||||
var deduped models.Records
|
||||
for _, rec := range newRecs {
|
||||
if !dedupedMap[rec.GetTargetField()] {
|
||||
dedupedMap[rec.GetTargetField()] = true
|
||||
deduped = append(deduped, rec)
|
||||
}
|
||||
}
|
||||
return deduped
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue