From c22f20db2cf7aa3607fe5a8d5c8c951995ff0cbf Mon Sep 17 00:00:00 2001 From: Jens Willemsens <6514515+JenswBE@users.noreply.github.com> Date: Tue, 2 Jul 2024 00:03:31 +0200 Subject: [PATCH] DESEC: Fix init (#3017) Co-authored-by: Tom Limoncelli --- providers/desec/desecProvider.go | 45 +++++------ providers/desec/protocol.go | 132 +++++++++++++++++-------------- 2 files changed, 92 insertions(+), 85 deletions(-) diff --git a/providers/desec/desecProvider.go b/providers/desec/desecProvider.go index 1134583f9..dee3f4fb1 100644 --- a/providers/desec/desecProvider.go +++ b/providers/desec/desecProvider.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "strings" "github.com/StackExchange/dnscontrol/v4/models" "github.com/StackExchange/dnscontrol/v4/pkg/diff" @@ -21,18 +22,10 @@ Info required in `creds.json`: // NewDeSec creates the provider. func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) { c := &desecProvider{} - c.creds.token = m["auth-token"] - if c.creds.token == "" { + c.token = strings.TrimSpace(m["auth-token"]) + if c.token == "" { return nil, fmt.Errorf("missing deSEC auth-token") } - if err := c.authenticate(); err != nil { - return nil, fmt.Errorf("authentication failed") - } - //DomainIndex is used for corrections (minttl) and domain creation - if err := c.initializeDomainIndex(); err != nil { - return nil, err - } - return c, nil } @@ -41,7 +34,7 @@ var features = providers.DocumentationNotes{ // See providers/capabilities.go for the entire list of capabilities. providers.CanAutoDNSSEC: providers.Can("deSEC always signs all records. When trying to disable, a notice is printed."), providers.CanGetZones: providers.Can(), - providers.CanConcur: providers.Cannot(), + providers.CanConcur: providers.Can(), providers.CanUseAlias: providers.Unimplemented("Apex aliasing is supported via new SVCB and HTTPS record types. For details, check the deSEC docs."), providers.CanUseCAA: providers.Can(), providers.CanUseDS: providers.Can(), @@ -119,9 +112,13 @@ func (c *desecProvider) GetZoneRecords(domain string, meta map[string]string) (m // EnsureZoneExists creates a zone if it does not exist func (c *desecProvider) EnsureZoneExists(domain string) error { - c.mutex.Lock() - defer c.mutex.Unlock() - if _, ok := c.domainIndex[domain]; ok { + _, ok, err := c.searchDomainIndex(domain) + if err != nil { + return err + } + + if ok { + // Domain already exists return nil } return c.createDomain(domain) @@ -155,14 +152,14 @@ func PrepDesiredRecords(dc *models.DomainConfig, minTTL uint32) { // GetZoneRecordsCorrections returns a list of corrections that will turn existing records into dc.Records. func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, existing models.Records) ([]*models.Correction, error) { - var minTTL uint32 - c.mutex.Lock() - if ttl, ok := c.domainIndex[dc.Name]; !ok { - minTTL = 3600 - } else { - minTTL = ttl + minTTL, ok, err := c.searchDomainIndex(dc.Name) + if err != nil { + return nil, err } - c.mutex.Unlock() + if !ok { + minTTL = 3600 + } + PrepDesiredRecords(dc, minTTL) keysToUpdate, toReport, err := diff.NewCompat(dc).ChangedGroups(existing) @@ -250,9 +247,5 @@ func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, exist // ListZones return all the zones in the account func (c *desecProvider) ListZones() ([]string, error) { - var domains []string - for domain := range c.domainIndex { - domains = append(domains, domain) - } - return domains, nil + return c.listDomainIndex() } diff --git a/providers/desec/protocol.go b/providers/desec/protocol.go index f056352b5..af5b3980d 100644 --- a/providers/desec/protocol.go +++ b/providers/desec/protocol.go @@ -19,14 +19,9 @@ const apiBase = "https://desec.io/api/v1" // Api layer for desec type desecProvider struct { - domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl) - creds struct { - tokenid string - token string - user string - password string - } - mutex sync.Mutex + domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl) + domainIndexLock sync.Mutex + token string } type domainObject struct { @@ -66,86 +61,105 @@ type nonFieldError struct { Errors []string `json:"non_field_errors"` } -func (c *desecProvider) authenticate() error { - endpoint := "/auth/account/" - var _, resp, err = c.get(endpoint, "GET") - //restricted tokens are valid, but get 403 on /auth/account - //invalid tokens get 401 - if resp.StatusCode == 403 { - return nil - } - if err != nil { - return err +// withDomainIndex checks if the domain index is initialized. If not, it's fetched from the deSEC API. +// Next, the provided readFn function is executed to extract data from the domain index. +func (c *desecProvider) withDomainIndex(readFn func(domainIndex map[string]uint32)) error { + // Lock index + c.domainIndexLock.Lock() + defer c.domainIndexLock.Unlock() + + // Init index if needed + if c.domainIndex == nil { + printer.Debugf("Domain index not yet populated, fetching now\n") + var err error + c.domainIndex, err = c.fetchDomainIndex() + if err != nil { + return fmt.Errorf("failed to fetch domain index: %w", err) + } } + + // Execute handler on index + readFn(c.domainIndex) return nil } -func (c *desecProvider) initializeDomainIndex() error { - c.mutex.Lock() - defer c.mutex.Unlock() - if c.domainIndex != nil { - return nil - } + +// listDomainIndex lists all the available domains in the domain index +func (c *desecProvider) listDomainIndex() (domains []string, err error) { + err = c.withDomainIndex(func(domainIndex map[string]uint32) { + domains = make([]string, 0, len(domainIndex)) + for domain := range domainIndex { + domains = append(domains, domain) + } + }) + return +} + +// searchDomainIndex performs a lookup to the domain index for the TTL of the domain +func (c *desecProvider) searchDomainIndex(domain string) (ttl uint32, found bool, err error) { + err = c.withDomainIndex(func(domainIndex map[string]uint32) { + ttl, found = domainIndex[domain] + }) + return +} + +func (c *desecProvider) fetchDomainIndex() (map[string]uint32, error) { endpoint := "/domains/" + var domainIndex map[string]uint32 var bodyString, resp, err = c.get(endpoint, "GET") if resp.StatusCode == 400 && resp.Header.Get("Link") != "" { //pagination is required - links := c.convertLinks(resp.Header.Get("Link")) + links := convertLinks(resp.Header.Get("Link")) endpoint = links["first"] printer.Debugf("initial endpoint %s\n", endpoint) for endpoint != "" { bodyString, resp, err = c.get(endpoint, "GET") if err != nil { - if resp.StatusCode == 404 { - return nil - } - return fmt.Errorf("failed fetching domains: %s", err) + return nil, fmt.Errorf("failed fetching domains: %s", err) } - err = c.buildIndexFromResponse(bodyString) + domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString) if err != nil { - return fmt.Errorf("failed fetching domains: %s", err) + return nil, fmt.Errorf("failed fetching domains: %s", err) } - links = c.convertLinks(resp.Header.Get("Link")) + links = convertLinks(resp.Header.Get("Link")) endpoint = links["next"] printer.Debugf("next endpoint %s\n", endpoint) } - printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex)) - return nil //domainIndex was build using pagination without errors + printer.Debugf("Domain Index fetched with pagination (%d domains)\n", len(domainIndex)) + return domainIndex, nil //domainIndex was build using pagination without errors } //no pagination required if err != nil && resp.StatusCode != 400 { - if resp.StatusCode == 404 { - return nil - } - return fmt.Errorf("failed fetching domains: %s", err) + return nil, fmt.Errorf("failed fetching domains: %s", err) } - err = c.buildIndexFromResponse(bodyString) - if err == nil { - printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex)) + domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString) + if err != nil { + return nil, err } - return err + printer.Debugf("Domain Index fetched without pagination (%d domains)\n", len(domainIndex)) + return domainIndex, nil } -// buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex -func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error { - if c.domainIndex == nil { - c.domainIndex = map[string]uint32{} - } +func appendDomainIndexFromResponse(domainIndex map[string]uint32, bodyString []byte) (map[string]uint32, error) { var dr []domainObject err := json.Unmarshal(bodyString, &dr) if err != nil { - return err + return nil, err + } + + if domainIndex == nil { + domainIndex = make(map[string]uint32, len(dr)) } for _, domain := range dr { //deSEC allows different minimum ttls per domain //we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors. - c.domainIndex[domain.Name] = domain.MinimumTTL + domainIndex[domain.Name] = domain.MinimumTTL } - return nil + return domainIndex, nil } // Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/main/fetch_zone.py#L13) -func (c *desecProvider) convertLinks(links string) map[string]string { +func convertLinks(links string) map[string]string { mapping := make(map[string]string) printer.Debugf("Header: %s\n", links) for _, link := range strings.Split(links, ", ") { @@ -173,7 +187,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) { var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET") if resp.StatusCode == 400 && resp.Header.Get("Link") != "" { //pagination required - links := c.convertLinks(resp.Header.Get("Link")) + links := convertLinks(resp.Header.Get("Link")) endpoint = links["first"] printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain)) for endpoint != "" { @@ -189,7 +203,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) { return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err) } rrsNew = append(rrsNew, tmp...) - links = c.convertLinks(resp.Header.Get("Link")) + links = convertLinks(resp.Header.Get("Link")) endpoint = links["next"] printer.Debugf("getRecords: next endpoint %s\n", endpoint) } @@ -282,7 +296,7 @@ retry: client := &http.Client{} req, _ := http.NewRequest(method, endpoint, nil) q := req.URL.Query() - req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token)) + req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token)) req.URL.RawQuery = q.Encode() @@ -304,12 +318,12 @@ retry: if wait > 180 { return []byte{}, resp, fmt.Errorf("rate limiting exceeded") } - printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor) + printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor) time.Sleep(time.Duration(wait+1) * time.Second) goto retry } } - printer.Warnf("Rate limiting.. waiting for 500 milliseconds") + printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n") time.Sleep(500 * time.Millisecond) goto retry } @@ -346,7 +360,7 @@ retry: } q := req.URL.Query() if endpoint != "/auth/login/" { - req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token)) + req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token)) } req.Header.Set("Content-Type", "application/json") @@ -371,12 +385,12 @@ retry: if wait > 180 { return []byte{}, fmt.Errorf("rate limiting exceeded") } - printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor) + printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor) time.Sleep(time.Duration(wait+1) * time.Second) goto retry } } - printer.Warnf("Rate limiting.. waiting for 500 milliseconds") + printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n") time.Sleep(500 * time.Millisecond) goto retry }