MAINTENANCE: Return error instead of panic when converting RR to RC (#1199)

This commit is contained in:
nemunaire 2021-07-06 17:03:29 +02:00 committed by GitHub
parent 00595a895f
commit 80f22df705
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 33 deletions

View file

@ -4,7 +4,6 @@ package models
import ( import (
"fmt" "fmt"
"log"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -42,17 +41,21 @@ func (rc *RecordConfig) String() string {
// Conversions // Conversions
// RRstoRCs converts []dns.RR to []RecordConfigs. // RRstoRCs converts []dns.RR to []RecordConfigs.
func RRstoRCs(rrs []dns.RR, origin string) Records { func RRstoRCs(rrs []dns.RR, origin string) (Records, error) {
rcs := make(Records, 0, len(rrs)) rcs := make(Records, 0, len(rrs))
for _, r := range rrs { for _, r := range rrs {
rc := RRtoRC(r, origin) rc, err := RRtoRC(r, origin)
if err != nil {
return nil, err
}
rcs = append(rcs, &rc) rcs = append(rcs, &rc)
} }
return rcs return rcs, nil
} }
// RRtoRC converts dns.RR to RecordConfig // RRtoRC converts dns.RR to RecordConfig
func RRtoRC(rr dns.RR, origin string) RecordConfig { func RRtoRC(rr dns.RR, origin string) (RecordConfig, error) {
// Convert's dns.RR into our native data type (RecordConfig). // Convert's dns.RR into our native data type (RecordConfig).
// Records are translated directly with no changes. // Records are translated directly with no changes.
header := rr.Header() header := rr.Header()
@ -61,43 +64,41 @@ func RRtoRC(rr dns.RR, origin string) RecordConfig {
rc.TTL = header.Ttl rc.TTL = header.Ttl
rc.Original = rr rc.Original = rr
rc.SetLabelFromFQDN(strings.TrimSuffix(header.Name, "."), origin) rc.SetLabelFromFQDN(strings.TrimSuffix(header.Name, "."), origin)
var err error
switch v := rr.(type) { // #rtype_variations switch v := rr.(type) { // #rtype_variations
case *dns.A: case *dns.A:
panicInvalid(rc.SetTarget(v.A.String())) err = rc.SetTarget(v.A.String())
case *dns.AAAA: case *dns.AAAA:
panicInvalid(rc.SetTarget(v.AAAA.String())) err = rc.SetTarget(v.AAAA.String())
case *dns.CAA: case *dns.CAA:
panicInvalid(rc.SetTargetCAA(v.Flag, v.Tag, v.Value)) err = rc.SetTargetCAA(v.Flag, v.Tag, v.Value)
case *dns.CNAME: case *dns.CNAME:
panicInvalid(rc.SetTarget(v.Target)) err = rc.SetTarget(v.Target)
case *dns.DS: case *dns.DS:
panicInvalid(rc.SetTargetDS(v.KeyTag, v.Algorithm, v.DigestType, v.Digest)) err = rc.SetTargetDS(v.KeyTag, v.Algorithm, v.DigestType, v.Digest)
case *dns.MX: case *dns.MX:
panicInvalid(rc.SetTargetMX(v.Preference, v.Mx)) err = rc.SetTargetMX(v.Preference, v.Mx)
case *dns.NS: case *dns.NS:
panicInvalid(rc.SetTarget(v.Ns)) err = rc.SetTarget(v.Ns)
case *dns.PTR: case *dns.PTR:
panicInvalid(rc.SetTarget(v.Ptr)) err = rc.SetTarget(v.Ptr)
case *dns.NAPTR: case *dns.NAPTR:
panicInvalid(rc.SetTargetNAPTR(v.Order, v.Preference, v.Flags, v.Service, v.Regexp, v.Replacement)) err = rc.SetTargetNAPTR(v.Order, v.Preference, v.Flags, v.Service, v.Regexp, v.Replacement)
case *dns.SOA: case *dns.SOA:
panicInvalid(rc.SetTargetSOA(v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)) err = rc.SetTargetSOA(v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.SRV: case *dns.SRV:
panicInvalid(rc.SetTargetSRV(v.Priority, v.Weight, v.Port, v.Target)) err = rc.SetTargetSRV(v.Priority, v.Weight, v.Port, v.Target)
case *dns.SSHFP: case *dns.SSHFP:
panicInvalid(rc.SetTargetSSHFP(v.Algorithm, v.Type, v.FingerPrint)) err = rc.SetTargetSSHFP(v.Algorithm, v.Type, v.FingerPrint)
case *dns.TLSA: case *dns.TLSA:
panicInvalid(rc.SetTargetTLSA(v.Usage, v.Selector, v.MatchingType, v.Certificate)) err = rc.SetTargetTLSA(v.Usage, v.Selector, v.MatchingType, v.Certificate)
case *dns.TXT: case *dns.TXT:
panicInvalid(rc.SetTargetTXTs(v.Txt)) err = rc.SetTargetTXTs(v.Txt)
default: default:
log.Fatalf("rrToRecord: Unimplemented zone record type=%s (%v)\n", rc.Type, rr) return *rc, fmt.Errorf("rrToRecord: Unimplemented zone record type=%s (%v)\n", rc.Type, rr)
} }
return *rc
}
func panicInvalid(err error) {
if err != nil { if err != nil {
panic(fmt.Errorf("unparsable record received from BIND: %w", err)) return *rc, fmt.Errorf("unparsable record received: %w", err)
} }
return *rc, nil
} }

View file

@ -45,7 +45,12 @@ func MostCommonTTL(records models.Records) uint32 {
// WriteZoneFileRR is a helper for when you have []dns.RR instead of models.Records // WriteZoneFileRR is a helper for when you have []dns.RR instead of models.Records
func WriteZoneFileRR(w io.Writer, records []dns.RR, origin string) error { func WriteZoneFileRR(w io.Writer, records []dns.RR, origin string) error {
return WriteZoneFileRC(w, models.RRstoRCs(records, origin), origin, 0, nil) rcs, err := models.RRstoRCs(records, origin)
if err != nil {
return err
}
return WriteZoneFileRC(w, rcs, origin, 0, nil)
} }
// WriteZoneFileRC writes a beautifully formatted zone file. // WriteZoneFileRC writes a beautifully formatted zone file.

View file

@ -48,7 +48,10 @@ func TestMostCommonTtl(t *testing.T) {
// All records are TTL=100 // All records are TTL=100
records = nil records = nil
records, e = append(records, r1, r1, r1), 100 records, e = append(records, r1, r1, r1), 100
x := models.RRstoRCs(records, "bosun.org") x, err := models.RRstoRCs(records, "bosun.org")
if err != nil {
panic(err)
}
g = MostCommonTTL(x) g = MostCommonTTL(x)
if e != g { if e != g {
t.Fatalf("expected %d; got %d\n", e, g) t.Fatalf("expected %d; got %d\n", e, g)
@ -57,7 +60,11 @@ func TestMostCommonTtl(t *testing.T) {
// Mixture of TTLs with an obvious winner. // Mixture of TTLs with an obvious winner.
records = nil records = nil
records, e = append(records, r1, r2, r2), 200 records, e = append(records, r1, r2, r2), 200
g = MostCommonTTL(models.RRstoRCs(records, "bosun.org")) rcs, err := models.RRstoRCs(records, "bosun.org")
if err != nil {
panic(err)
}
g = MostCommonTTL(rcs)
if e != g { if e != g {
t.Fatalf("expected %d; got %d\n", e, g) t.Fatalf("expected %d; got %d\n", e, g)
} }
@ -65,7 +72,11 @@ func TestMostCommonTtl(t *testing.T) {
// 3-way tie. Largest TTL should be used. // 3-way tie. Largest TTL should be used.
records = nil records = nil
records, e = append(records, r1, r2, r3), 300 records, e = append(records, r1, r2, r3), 300
g = MostCommonTTL(models.RRstoRCs(records, "bosun.org")) rcs, err = models.RRstoRCs(records, "bosun.org")
if err != nil {
panic(err)
}
g = MostCommonTTL(rcs)
if e != g { if e != g {
t.Fatalf("expected %d; got %d\n", e, g) t.Fatalf("expected %d; got %d\n", e, g)
} }
@ -73,7 +84,11 @@ func TestMostCommonTtl(t *testing.T) {
// NS records are ignored. // NS records are ignored.
records = nil records = nil
records, e = append(records, r1, r4, r5), 100 records, e = append(records, r1, r4, r5), 100
g = MostCommonTTL(models.RRstoRCs(records, "bosun.org")) rcs, err = models.RRstoRCs(records, "bosun.org")
if err != nil {
panic(err)
}
g = MostCommonTTL(rcs)
if e != g { if e != g {
t.Fatalf("expected %d; got %d\n", e, g) t.Fatalf("expected %d; got %d\n", e, g)
} }
@ -289,7 +304,10 @@ func TestWriteZoneFileSynth(t *testing.T) {
rsynz := &models.RecordConfig{Type: "R53_ALIAS", TTL: 300} rsynz := &models.RecordConfig{Type: "R53_ALIAS", TTL: 300}
rsynz.SetLabel("zalias", "bosun.org") rsynz.SetLabel("zalias", "bosun.org")
recs := models.RRstoRCs([]dns.RR{r1, r2, r3}, "bosun.org") recs, err := models.RRstoRCs([]dns.RR{r1, r2, r3}, "bosun.org")
if err != nil {
panic(err)
}
recs = append(recs, rsynm) recs = append(recs, rsynm)
recs = append(recs, rsynm) recs = append(recs, rsynm)
recs = append(recs, rsynz) recs = append(recs, rsynz)

View file

@ -286,7 +286,10 @@ func (c *axfrddnsProvider) GetZoneRecords(domain string) (models.Records, error)
} }
continue continue
default: default:
rec := models.RRtoRC(rr, domain) rec, err := models.RRtoRC(rr, domain)
if err != nil {
return nil, err
}
foundRecords = append(foundRecords, &rec) foundRecords = append(foundRecords, &rec)
} }
} }

View file

@ -171,7 +171,10 @@ func (c *bindProvider) GetZoneRecords(domain string) (models.Records, error) {
zp := dns.NewZoneParser(strings.NewReader(string(content)), domain, c.zonefile) zp := dns.NewZoneParser(strings.NewReader(string(content)), domain, c.zonefile)
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
rec := models.RRtoRC(rr, domain) rec, err := models.RRtoRC(rr, domain)
if err != nil {
return nil, err
}
// FIXME(tlim): Empty branch? Is the intention to skip SOAs? // FIXME(tlim): Empty branch? Is the intention to skip SOAs?
if rec.Type == "SOA" { if rec.Type == "SOA" {
} }