DESEC: Fix init (#3017)

Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
Jens Willemsens 2024-07-02 00:03:31 +02:00 committed by GitHub
parent 2f155ced46
commit c22f20db2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 92 additions and 85 deletions

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"strings"
"github.com/StackExchange/dnscontrol/v4/models"
"github.com/StackExchange/dnscontrol/v4/pkg/diff"
@ -21,18 +22,10 @@ Info required in `creds.json`:
// NewDeSec creates the provider.
func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
c := &desecProvider{}
c.creds.token = m["auth-token"]
if c.creds.token == "" {
c.token = strings.TrimSpace(m["auth-token"])
if c.token == "" {
return nil, fmt.Errorf("missing deSEC auth-token")
}
if err := c.authenticate(); err != nil {
return nil, fmt.Errorf("authentication failed")
}
//DomainIndex is used for corrections (minttl) and domain creation
if err := c.initializeDomainIndex(); err != nil {
return nil, err
}
return c, nil
}
@ -41,7 +34,7 @@ var features = providers.DocumentationNotes{
// See providers/capabilities.go for the entire list of capabilities.
providers.CanAutoDNSSEC: providers.Can("deSEC always signs all records. When trying to disable, a notice is printed."),
providers.CanGetZones: providers.Can(),
providers.CanConcur: providers.Cannot(),
providers.CanConcur: providers.Can(),
providers.CanUseAlias: providers.Unimplemented("Apex aliasing is supported via new SVCB and HTTPS record types. For details, check the deSEC docs."),
providers.CanUseCAA: providers.Can(),
providers.CanUseDS: providers.Can(),
@ -119,9 +112,13 @@ func (c *desecProvider) GetZoneRecords(domain string, meta map[string]string) (m
// EnsureZoneExists creates a zone if it does not exist
func (c *desecProvider) EnsureZoneExists(domain string) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if _, ok := c.domainIndex[domain]; ok {
_, ok, err := c.searchDomainIndex(domain)
if err != nil {
return err
}
if ok {
// Domain already exists
return nil
}
return c.createDomain(domain)
@ -155,14 +152,14 @@ func PrepDesiredRecords(dc *models.DomainConfig, minTTL uint32) {
// GetZoneRecordsCorrections returns a list of corrections that will turn existing records into dc.Records.
func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, existing models.Records) ([]*models.Correction, error) {
var minTTL uint32
c.mutex.Lock()
if ttl, ok := c.domainIndex[dc.Name]; !ok {
minTTL = 3600
} else {
minTTL = ttl
minTTL, ok, err := c.searchDomainIndex(dc.Name)
if err != nil {
return nil, err
}
c.mutex.Unlock()
if !ok {
minTTL = 3600
}
PrepDesiredRecords(dc, minTTL)
keysToUpdate, toReport, err := diff.NewCompat(dc).ChangedGroups(existing)
@ -250,9 +247,5 @@ func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, exist
// ListZones return all the zones in the account
func (c *desecProvider) ListZones() ([]string, error) {
var domains []string
for domain := range c.domainIndex {
domains = append(domains, domain)
}
return domains, nil
return c.listDomainIndex()
}

View file

@ -19,14 +19,9 @@ const apiBase = "https://desec.io/api/v1"
// Api layer for desec
type desecProvider struct {
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
creds struct {
tokenid string
token string
user string
password string
}
mutex sync.Mutex
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
domainIndexLock sync.Mutex
token string
}
type domainObject struct {
@ -66,86 +61,105 @@ type nonFieldError struct {
Errors []string `json:"non_field_errors"`
}
func (c *desecProvider) authenticate() error {
endpoint := "/auth/account/"
var _, resp, err = c.get(endpoint, "GET")
//restricted tokens are valid, but get 403 on /auth/account
//invalid tokens get 401
if resp.StatusCode == 403 {
return nil
}
if err != nil {
return err
// withDomainIndex checks if the domain index is initialized. If not, it's fetched from the deSEC API.
// Next, the provided readFn function is executed to extract data from the domain index.
func (c *desecProvider) withDomainIndex(readFn func(domainIndex map[string]uint32)) error {
// Lock index
c.domainIndexLock.Lock()
defer c.domainIndexLock.Unlock()
// Init index if needed
if c.domainIndex == nil {
printer.Debugf("Domain index not yet populated, fetching now\n")
var err error
c.domainIndex, err = c.fetchDomainIndex()
if err != nil {
return fmt.Errorf("failed to fetch domain index: %w", err)
}
}
// Execute handler on index
readFn(c.domainIndex)
return nil
}
func (c *desecProvider) initializeDomainIndex() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.domainIndex != nil {
return nil
}
// listDomainIndex lists all the available domains in the domain index
func (c *desecProvider) listDomainIndex() (domains []string, err error) {
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
domains = make([]string, 0, len(domainIndex))
for domain := range domainIndex {
domains = append(domains, domain)
}
})
return
}
// searchDomainIndex performs a lookup to the domain index for the TTL of the domain
func (c *desecProvider) searchDomainIndex(domain string) (ttl uint32, found bool, err error) {
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
ttl, found = domainIndex[domain]
})
return
}
func (c *desecProvider) fetchDomainIndex() (map[string]uint32, error) {
endpoint := "/domains/"
var domainIndex map[string]uint32
var bodyString, resp, err = c.get(endpoint, "GET")
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
//pagination is required
links := c.convertLinks(resp.Header.Get("Link"))
links := convertLinks(resp.Header.Get("Link"))
endpoint = links["first"]
printer.Debugf("initial endpoint %s\n", endpoint)
for endpoint != "" {
bodyString, resp, err = c.get(endpoint, "GET")
if err != nil {
if resp.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
err = c.buildIndexFromResponse(bodyString)
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
if err != nil {
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
links = c.convertLinks(resp.Header.Get("Link"))
links = convertLinks(resp.Header.Get("Link"))
endpoint = links["next"]
printer.Debugf("next endpoint %s\n", endpoint)
}
printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex))
return nil //domainIndex was build using pagination without errors
printer.Debugf("Domain Index fetched with pagination (%d domains)\n", len(domainIndex))
return domainIndex, nil //domainIndex was build using pagination without errors
}
//no pagination required
if err != nil && resp.StatusCode != 400 {
if resp.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed fetching domains: %s", err)
return nil, fmt.Errorf("failed fetching domains: %s", err)
}
err = c.buildIndexFromResponse(bodyString)
if err == nil {
printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex))
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
if err != nil {
return nil, err
}
return err
printer.Debugf("Domain Index fetched without pagination (%d domains)\n", len(domainIndex))
return domainIndex, nil
}
// buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex
func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error {
if c.domainIndex == nil {
c.domainIndex = map[string]uint32{}
}
func appendDomainIndexFromResponse(domainIndex map[string]uint32, bodyString []byte) (map[string]uint32, error) {
var dr []domainObject
err := json.Unmarshal(bodyString, &dr)
if err != nil {
return err
return nil, err
}
if domainIndex == nil {
domainIndex = make(map[string]uint32, len(dr))
}
for _, domain := range dr {
//deSEC allows different minimum ttls per domain
//we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
c.domainIndex[domain.Name] = domain.MinimumTTL
domainIndex[domain.Name] = domain.MinimumTTL
}
return nil
return domainIndex, nil
}
// Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/main/fetch_zone.py#L13)
func (c *desecProvider) convertLinks(links string) map[string]string {
func convertLinks(links string) map[string]string {
mapping := make(map[string]string)
printer.Debugf("Header: %s\n", links)
for _, link := range strings.Split(links, ", ") {
@ -173,7 +187,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
//pagination required
links := c.convertLinks(resp.Header.Get("Link"))
links := convertLinks(resp.Header.Get("Link"))
endpoint = links["first"]
printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain))
for endpoint != "" {
@ -189,7 +203,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err)
}
rrsNew = append(rrsNew, tmp...)
links = c.convertLinks(resp.Header.Get("Link"))
links = convertLinks(resp.Header.Get("Link"))
endpoint = links["next"]
printer.Debugf("getRecords: next endpoint %s\n", endpoint)
}
@ -282,7 +296,7 @@ retry:
client := &http.Client{}
req, _ := http.NewRequest(method, endpoint, nil)
q := req.URL.Query()
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))
req.URL.RawQuery = q.Encode()
@ -304,12 +318,12 @@ retry:
if wait > 180 {
return []byte{}, resp, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
time.Sleep(500 * time.Millisecond)
goto retry
}
@ -346,7 +360,7 @@ retry:
}
q := req.URL.Query()
if endpoint != "/auth/login/" {
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))
}
req.Header.Set("Content-Type", "application/json")
@ -371,12 +385,12 @@ retry:
if wait > 180 {
return []byte{}, fmt.Errorf("rate limiting exceeded")
}
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
time.Sleep(time.Duration(wait+1) * time.Second)
goto retry
}
}
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
time.Sleep(500 * time.Millisecond)
goto retry
}