DESEC: Implements support for long / multistring txt records (#1204)

* use /auth/account endpoint for token validation
this implements the token validation using the /auth/account api endpoint as suggested in #1177 instead of fetching the domain list

* deSEC: add support for long txt records #996

* deSEC: add support for a different api error response
relates to #996 where we had insufficient error output due to unknown api error format

* deSEC: remove unused fetchDomainList function

* deSEC: improve error handling

* deSEC: support for long / multistring txt records
the previous commit was broken this is now working (CRUD)

* deSEC: document what desecProvider.domainIndex is used for

* deSEC: handle the rate limiting correctly
we try to use the Retry-After header to determine how long we should sleep until retry

* deSEC: further improvement of rate limit handling
we cut off if the Retry-After header exceeds 3 minutes because this might be the daily limit.

Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
Georg 2021-07-08 16:06:54 +02:00 committed by GitHub
parent 0847242e9f
commit 228b57e445
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 33 deletions

View file

@ -4,7 +4,6 @@ package desec
import ( import (
"fmt" "fmt"
"strings"
"github.com/StackExchange/dnscontrol/v3/models" "github.com/StackExchange/dnscontrol/v3/models"
"github.com/StackExchange/dnscontrol/v3/pkg/printer" "github.com/StackExchange/dnscontrol/v3/pkg/printer"
@ -25,9 +24,7 @@ func nativeToRecords(n resourceRecord, origin string) (rcs []*models.RecordConfi
} }
rc.SetLabel(n.Subname, origin) rc.SetLabel(n.Subname, origin)
switch rtype := n.Type; rtype { switch rtype := n.Type; rtype {
case "TXT": default: // "A", "AAAA", "CAA", "NS", "CNAME", "MX", "PTR", "SRV", "TXT"
rc.SetTargetTXT(value)
default: // "A", "AAAA", "CAA", "NS", "CNAME", "MX", "PTR", "SRV"
if err := rc.PopulateFromString(rtype, value, origin); err != nil { if err := rc.PopulateFromString(rtype, value, origin); err != nil {
panic(fmt.Errorf("unparsable record received from deSEC: %w", err)) panic(fmt.Errorf("unparsable record received from deSEC: %w", err))
} }
@ -45,7 +42,6 @@ func recordsToNative(rcs []*models.RecordConfig, origin string) []resourceRecord
var keys = map[models.RecordKey]*resourceRecord{} var keys = map[models.RecordKey]*resourceRecord{}
var zrs []resourceRecord var zrs []resourceRecord
for _, r := range rcs { for _, r := range rcs {
label := r.GetLabel() label := r.GetLabel()
if label == "@" { if label == "@" {
@ -61,9 +57,6 @@ func recordsToNative(rcs []*models.RecordConfig, origin string) []resourceRecord
Subname: label, Subname: label,
Records: []string{r.GetTargetCombined()}, Records: []string{r.GetTargetCombined()},
} }
if r.Type == "TXT" {
zr.Records = []string{strings.Join(r.TxtStrings, "")}
}
zrs = append(zrs, zr) zrs = append(zrs, zr)
//keys[key] = &zr // This didn't work. //keys[key] = &zr // This didn't work.
keys[key] = &zrs[len(zrs)-1] // This does work. I don't know why. keys[key] = &zrs[len(zrs)-1] // This does work. I don't know why.

View file

@ -9,6 +9,7 @@ import (
"github.com/StackExchange/dnscontrol/v3/models" "github.com/StackExchange/dnscontrol/v3/models"
"github.com/StackExchange/dnscontrol/v3/pkg/diff" "github.com/StackExchange/dnscontrol/v3/pkg/diff"
"github.com/StackExchange/dnscontrol/v3/pkg/printer" "github.com/StackExchange/dnscontrol/v3/pkg/printer"
"github.com/StackExchange/dnscontrol/v3/pkg/txtutil"
"github.com/StackExchange/dnscontrol/v3/providers" "github.com/StackExchange/dnscontrol/v3/providers"
"github.com/miekg/dns/dnsutil" "github.com/miekg/dns/dnsutil"
) )
@ -23,13 +24,12 @@ Info required in `creds.json`:
func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) { func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
c := &desecProvider{} c := &desecProvider{}
c.creds.token = m["auth-token"] c.creds.token = m["auth-token"]
c.domainIndex = map[string]uint32{}
if c.creds.token == "" { if c.creds.token == "" {
return nil, fmt.Errorf("missing deSEC auth-token") return nil, fmt.Errorf("missing deSEC auth-token")
} }
if err := c.authenticate(); err != nil {
// Get a domain to validate authentication return nil, fmt.Errorf("authentication failed")
if err := c.fetchDomainList(); err != nil {
return nil, err
} }
return c, nil return c, nil
@ -99,6 +99,7 @@ func (c *desecProvider) GetZoneRecords(domain string) (models.Records, error) {
// Convert them to DNScontrol's native format: // Convert them to DNScontrol's native format:
existingRecords := []*models.RecordConfig{} existingRecords := []*models.RecordConfig{}
//spew.Dump(records)
for _, rr := range records { for _, rr := range records {
existingRecords = append(existingRecords, nativeToRecords(rr, domain)...) existingRecords = append(existingRecords, nativeToRecords(rr, domain)...)
} }
@ -107,7 +108,7 @@ func (c *desecProvider) GetZoneRecords(domain string) (models.Records, error) {
// EnsureDomainExists returns an error if domain doesn't exist. // EnsureDomainExists returns an error if domain doesn't exist.
func (c *desecProvider) EnsureDomainExists(domain string) error { func (c *desecProvider) EnsureDomainExists(domain string) error {
if err := c.fetchDomainList(); err != nil { if err := c.fetchDomain(domain); err != nil {
return err return err
} }
// domain already exists // domain already exists
@ -133,6 +134,7 @@ func PrepDesiredRecords(dc *models.DomainConfig, minTTL uint32) {
// confusing. // confusing.
dc.Punycode() dc.Punycode()
txtutil.SplitSingleLongTxt(dc.Records)
recordsToKeep := make([]*models.RecordConfig, 0, len(dc.Records)) recordsToKeep := make([]*models.RecordConfig, 0, len(dc.Records))
for _, rec := range dc.Records { for _, rec := range dc.Records {
if rec.Type == "ALIAS" { if rec.Type == "ALIAS" {

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/StackExchange/dnscontrol/v3/pkg/printer" "github.com/StackExchange/dnscontrol/v3/pkg/printer"
@ -15,7 +16,7 @@ const apiBase = "https://desec.io/api/v1"
// Api layer for desec // Api layer for desec
type desecProvider struct { type desecProvider struct {
domainIndex map[string]uint32 domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
nameserversNames []string nameserversNames []string
creds struct { creds struct {
tokenid string tokenid string
@ -58,24 +59,37 @@ type dnssecKey struct {
type errorResponse struct { type errorResponse struct {
Detail string `json:"detail"` Detail string `json:"detail"`
} }
type nonFieldError struct {
Errors []string `json:"non_field_errors"`
}
func (c *desecProvider) fetchDomainList() error { func (c *desecProvider) authenticate() error {
c.domainIndex = map[string]uint32{} endpoint := "/auth/account/"
var dr []domainObject var _, _, err = c.get(endpoint, "GET")
endpoint := "/domains/"
var bodyString, err = c.get(endpoint, "GET")
if err != nil { if err != nil {
return fmt.Errorf("Failed fetching domain list (deSEC): %s", err) return err
}
return nil
}
func (c *desecProvider) fetchDomain(domain string) error {
endpoint := fmt.Sprintf("/domains/%s", domain)
var dr domainObject
var bodyString, statuscode, err = c.get(endpoint, "GET")
if err != nil {
if statuscode == 404 {
return nil
}
return fmt.Errorf("Failed fetching domain: %s", err)
} }
err = json.Unmarshal(bodyString, &dr) err = json.Unmarshal(bodyString, &dr)
if err != nil { if err != nil {
return err return err
} }
for _, domain := range dr {
//We store the min ttl in the domain index //deSEC allows different minimum ttls per domain
//This will be used for validation and auto correction //we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
c.domainIndex[domain.Name] = domain.MinimumTTL c.domainIndex[dr.Name] = dr.MinimumTTL
}
return nil return nil
} }
@ -83,7 +97,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
endpoint := "/domains/%s/rrsets/" endpoint := "/domains/%s/rrsets/"
var rrs []rrResponse var rrs []rrResponse
var rrsNew []resourceRecord var rrsNew []resourceRecord
var bodyString, err = c.get(fmt.Sprintf(endpoint, domain), "GET") var bodyString, _, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
if err != nil { if err != nil {
return rrsNew, fmt.Errorf("Failed fetching records for domain %s (deSEC): %s", domain, err) return rrsNew, fmt.Errorf("Failed fetching records for domain %s (deSEC): %s", domain, err)
} }
@ -136,13 +150,13 @@ func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error {
func (c *desecProvider) deleteRR(domain, shortname, t string) error { func (c *desecProvider) deleteRR(domain, shortname, t string) error {
endpoint := fmt.Sprintf("/domains/%s/rrsets/%s/%s/", domain, shortname, t) endpoint := fmt.Sprintf("/domains/%s/rrsets/%s/%s/", domain, shortname, t)
if _, err := c.get(endpoint, "DELETE"); err != nil { if _, _, err := c.get(endpoint, "DELETE"); err != nil {
return fmt.Errorf("Failed delete RRset (deSEC): %v", err) return fmt.Errorf("Failed delete RRset (deSEC): %v", err)
} }
return nil return nil
} }
func (c *desecProvider) get(endpoint, method string) ([]byte, error) { func (c *desecProvider) get(endpoint, method string) ([]byte, int, error) {
retrycnt := 0 retrycnt := 0
retry: retry:
client := &http.Client{} client := &http.Client{}
@ -154,7 +168,7 @@ retry:
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return []byte{}, err return []byte{}, 0, err
} }
bodyString, _ := ioutil.ReadAll(resp.Body) bodyString, _ := ioutil.ReadAll(resp.Body)
@ -162,17 +176,38 @@ retry:
if resp.StatusCode > 299 { if resp.StatusCode > 299 {
if resp.StatusCode == 429 && retrycnt < 5 { if resp.StatusCode == 429 && retrycnt < 5 {
retrycnt++ retrycnt++
//we've got rate limiting and will try to get the Retry-After Header if this fails we fallback to sleep for 500ms max. 5 retries.
waitfor := resp.Header.Get("Retry-After")
if waitfor != "" {
wait, err := strconv.ParseInt(waitfor, 10, 64)
if err == nil {
if wait > 180 {
return []byte{}, 0, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
goto retry goto retry
} }
var errResp errorResponse var errResp errorResponse
var nfieldErrors []nonFieldError
err = json.Unmarshal(bodyString, &errResp) err = json.Unmarshal(bodyString, &errResp)
if err == nil { if err == nil {
return bodyString, fmt.Errorf("%s", errResp.Detail) return bodyString, resp.StatusCode, fmt.Errorf("%s", errResp.Detail)
} }
return bodyString, fmt.Errorf("HTTP status %d %s, the API does not provide more information", resp.StatusCode, resp.Status) err = json.Unmarshal(bodyString, &nfieldErrors)
if err == nil && len(nfieldErrors) > 0 {
if len(nfieldErrors[0].Errors) > 0 {
return bodyString, resp.StatusCode, fmt.Errorf("%s", nfieldErrors[0].Errors[0])
}
}
return bodyString, resp.StatusCode, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
} }
return bodyString, nil return bodyString, resp.StatusCode, nil
} }
func (c *desecProvider) post(endpoint, method string, payload []byte) ([]byte, error) { func (c *desecProvider) post(endpoint, method string, payload []byte) ([]byte, error) {
@ -202,15 +237,36 @@ retry:
if resp.StatusCode > 299 { if resp.StatusCode > 299 {
if resp.StatusCode == 429 && retrycnt < 5 { if resp.StatusCode == 429 && retrycnt < 5 {
retrycnt++ retrycnt++
//we've got rate limiting and will try to get the Retry-After Header if this fails we fallback to sleep for 500ms max. 5 retries.
waitfor := resp.Header.Get("Retry-After")
if waitfor != "" {
wait, err := strconv.ParseInt(waitfor, 10, 64)
if err == nil {
if wait > 180 {
return []byte{}, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
goto retry goto retry
} }
var errResp errorResponse var errResp errorResponse
var nfieldErrors []nonFieldError
err = json.Unmarshal(bodyString, &errResp) err = json.Unmarshal(bodyString, &errResp)
if err == nil { if err == nil {
return bodyString, fmt.Errorf("HTTP status %d %s details: %s", resp.StatusCode, resp.Status, errResp.Detail) return bodyString, fmt.Errorf("HTTP status %d %s details: %s", resp.StatusCode, resp.Status, errResp.Detail)
} }
return bodyString, fmt.Errorf("HTTP status %d %s, the API does not provide more information", resp.StatusCode, resp.Status) err = json.Unmarshal(bodyString, &nfieldErrors)
if err == nil && len(nfieldErrors) > 0 {
if len(nfieldErrors[0].Errors) > 0 {
return bodyString, fmt.Errorf("%s", nfieldErrors[0].Errors[0])
}
}
return bodyString, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
} }
//time.Sleep(334 * time.Millisecond) //time.Sleep(334 * time.Millisecond)
return bodyString, nil return bodyString, nil