diff --git a/pkg/normalize/validate.go b/pkg/normalize/validate.go index 664911054..4aa151b9a 100644 --- a/pkg/normalize/validate.go +++ b/pkg/normalize/validate.go @@ -3,6 +3,7 @@ package normalize import ( "fmt" "net" + "sort" "strings" "github.com/StackExchange/dnscontrol/v3/models" @@ -10,6 +11,7 @@ import ( "github.com/StackExchange/dnscontrol/v3/providers" "github.com/miekg/dns" "github.com/miekg/dns/dnsutil" + "golang.org/x/exp/slices" ) // Returns false if target does not validate. @@ -581,9 +583,10 @@ func checkDuplicates(records []*models.RecordConfig) (errs []error) { return errs } -func uniq(s []uint32) []uint32 { - seen := make(map[uint32]struct{}) - var result []uint32 +// uniq returns the unique values in a map. The result is sorted lexigraphically. +func uniq(s []string) []string { + seen := make(map[string]struct{}) + var result []string for _, k := range s { if _, ok := seen[k]; !ok { @@ -591,33 +594,113 @@ func uniq(s []uint32) []uint32 { result = append(result, k) } } + sort.Strings(result) return result } func checkLabelHasMultipleTTLs(records []*models.RecordConfig) (errs []error) { - m := make(map[string][]uint32) - for _, r := range records { - label := fmt.Sprintf("%s %s", r.GetLabelFQDN(), r.Type) + // The RFCs say that all records at a particular label should have + // the same TTL. Most providers don't care, and if they do the + // dnscontrol provider code usually picks the lowest TTL for all of them. - // collect the TTLs at this label. - m[label] = append(m[label], r.TTL) + // General algorithm: + // gather all records at a particular label. + // has[label] -> ttl -> type(s) + // for each label, if there is more than one ttl, output ttl:A/TXT ttl:TXT/NS + + // Find the inconsistencies: + m := make(map[string]map[uint32]map[string]bool) + for _, r := range records { + label := r.GetLabelFQDN() + ttl := r.TTL + rtype := r.Type + + if _, ok := m[label]; !ok { + m[label] = make(map[uint32]map[string]bool) + } + if _, ok := m[label][ttl]; !ok { + m[label][ttl] = make(map[string]bool) + } + m[label][ttl][rtype] = true } - for label := range m { - // The RFCs say that all records at a particular label should have - // the same TTL. Most providers don't care, and if they do the - // code usually picks the lowest TTL for all of them. - // - // If after the uniq() pass we still have more than one ttl, it - // means we have multiple TTLs for that label. - u := uniq(m[label]) - if len(u) > 1 { - errs = append(errs, Warning{fmt.Errorf("label with multipe TTLs: %s (%v)", label, u)}) + labels := make([]string, len(m)) + i := 0 + for k := range m { + labels[i] = k + i++ + } + + sort.Strings(labels) + slices.Compact(labels) + + // Less clear error message: + // for _, label := range labels { + // if len(m[label]) > 1 { + // result := "" + // for ttl, v := range m[label] { + // result += fmt.Sprintf(" %d:", ttl) + + // rtypes := make([]string, len(v)) + // i := 0 + // for k := range v { + // rtypes[i] = k + // i++ + // } + + // result += strings.Join(rtypes, "/") + // } + // errs = append(errs, Warning{fmt.Errorf("inconsistent TTLs at %q:%v", label, result)}) + // } + // } + + // Invert for a more clear error message: + for _, label := range labels { + if len(m[label]) > 1 { + r := make(map[string]map[uint32]bool) + for ttl, rtypes := range m[label] { + for rtype := range rtypes { + if _, ok := r[rtype]; !ok { + r[rtype] = make(map[uint32]bool) + } + r[rtype][ttl] = true + } + } + result := formatInconsistency(r) + errs = append(errs, Warning{fmt.Errorf("inconsistent TTLs at %q: %s", label, result)}) } } + return errs } +func formatInconsistency(r map[string]map[uint32]bool) string { + var rtypeResult []string + for rtype, ttlsMap := range r { + + ttlList := make([]int, len(ttlsMap)) + i := 0 + for k := range ttlsMap { + ttlList[i] = int(k) + i++ + } + + sort.Ints(ttlList) + + rtypeResult = append(rtypeResult, fmt.Sprintf("%s:%v", rtype, commaSepInts(ttlList))) + } + sort.Strings(rtypeResult) + return strings.Join(rtypeResult, " ") +} + +func commaSepInts(list []int) string { + slist := make([]string, len(list)) + for i, v := range list { + slist[i] = fmt.Sprintf("%d", v) + } + return strings.Join(slist, ",") +} + // We pull this out of checkProviderCapabilities() so that it's visible within // the package elsewhere, so that our test suite can look at the list of // capabilities we're checking and make sure that it's up-to-date. diff --git a/pkg/normalize/validate_test.go b/pkg/normalize/validate_test.go index 839b303c4..1efd4d278 100644 --- a/pkg/normalize/validate_test.go +++ b/pkg/normalize/validate_test.go @@ -2,7 +2,6 @@ package normalize import ( "fmt" - "reflect" "testing" "github.com/StackExchange/dnscontrol/v3/models" @@ -355,16 +354,6 @@ func TestCheckDuplicates_dup_ns(t *testing.T) { } } -func TestUniq(t *testing.T) { - a := []uint32{1, 2, 2, 3, 4, 5, 5, 6} - expected := []uint32{1, 2, 3, 4, 5, 6} - - r := uniq(a) - if !reflect.DeepEqual(r, expected) { - t.Error("Deduplicated slice is different than expected") - } -} - func TestCheckLabelHasMultipleTTLs(t *testing.T) { records := []*models.RecordConfig{ // different ttl per record