mirror of
https://github.com/StackExchange/dnscontrol.git
synced 2025-12-09 13:46:07 +08:00
DESEC: Fix init (#3017)
Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
parent
2f155ced46
commit
c22f20db2c
2 changed files with 92 additions and 85 deletions
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/StackExchange/dnscontrol/v4/models"
|
||||
"github.com/StackExchange/dnscontrol/v4/pkg/diff"
|
||||
|
|
@ -21,18 +22,10 @@ Info required in `creds.json`:
|
|||
// NewDeSec creates the provider.
|
||||
func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
|
||||
c := &desecProvider{}
|
||||
c.creds.token = m["auth-token"]
|
||||
if c.creds.token == "" {
|
||||
c.token = strings.TrimSpace(m["auth-token"])
|
||||
if c.token == "" {
|
||||
return nil, fmt.Errorf("missing deSEC auth-token")
|
||||
}
|
||||
if err := c.authenticate(); err != nil {
|
||||
return nil, fmt.Errorf("authentication failed")
|
||||
}
|
||||
//DomainIndex is used for corrections (minttl) and domain creation
|
||||
if err := c.initializeDomainIndex(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
|
@ -41,7 +34,7 @@ var features = providers.DocumentationNotes{
|
|||
// See providers/capabilities.go for the entire list of capabilities.
|
||||
providers.CanAutoDNSSEC: providers.Can("deSEC always signs all records. When trying to disable, a notice is printed."),
|
||||
providers.CanGetZones: providers.Can(),
|
||||
providers.CanConcur: providers.Cannot(),
|
||||
providers.CanConcur: providers.Can(),
|
||||
providers.CanUseAlias: providers.Unimplemented("Apex aliasing is supported via new SVCB and HTTPS record types. For details, check the deSEC docs."),
|
||||
providers.CanUseCAA: providers.Can(),
|
||||
providers.CanUseDS: providers.Can(),
|
||||
|
|
@ -119,9 +112,13 @@ func (c *desecProvider) GetZoneRecords(domain string, meta map[string]string) (m
|
|||
|
||||
// EnsureZoneExists creates a zone if it does not exist
|
||||
func (c *desecProvider) EnsureZoneExists(domain string) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if _, ok := c.domainIndex[domain]; ok {
|
||||
_, ok, err := c.searchDomainIndex(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
// Domain already exists
|
||||
return nil
|
||||
}
|
||||
return c.createDomain(domain)
|
||||
|
|
@ -155,14 +152,14 @@ func PrepDesiredRecords(dc *models.DomainConfig, minTTL uint32) {
|
|||
|
||||
// GetZoneRecordsCorrections returns a list of corrections that will turn existing records into dc.Records.
|
||||
func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, existing models.Records) ([]*models.Correction, error) {
|
||||
var minTTL uint32
|
||||
c.mutex.Lock()
|
||||
if ttl, ok := c.domainIndex[dc.Name]; !ok {
|
||||
minTTL = 3600
|
||||
} else {
|
||||
minTTL = ttl
|
||||
minTTL, ok, err := c.searchDomainIndex(dc.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
if !ok {
|
||||
minTTL = 3600
|
||||
}
|
||||
|
||||
PrepDesiredRecords(dc, minTTL)
|
||||
|
||||
keysToUpdate, toReport, err := diff.NewCompat(dc).ChangedGroups(existing)
|
||||
|
|
@ -250,9 +247,5 @@ func (c *desecProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, exist
|
|||
|
||||
// ListZones return all the zones in the account
|
||||
func (c *desecProvider) ListZones() ([]string, error) {
|
||||
var domains []string
|
||||
for domain := range c.domainIndex {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
return domains, nil
|
||||
return c.listDomainIndex()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,14 +19,9 @@ const apiBase = "https://desec.io/api/v1"
|
|||
|
||||
// Api layer for desec
|
||||
type desecProvider struct {
|
||||
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
|
||||
creds struct {
|
||||
tokenid string
|
||||
token string
|
||||
user string
|
||||
password string
|
||||
}
|
||||
mutex sync.Mutex
|
||||
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
|
||||
domainIndexLock sync.Mutex
|
||||
token string
|
||||
}
|
||||
|
||||
type domainObject struct {
|
||||
|
|
@ -66,86 +61,105 @@ type nonFieldError struct {
|
|||
Errors []string `json:"non_field_errors"`
|
||||
}
|
||||
|
||||
func (c *desecProvider) authenticate() error {
|
||||
endpoint := "/auth/account/"
|
||||
var _, resp, err = c.get(endpoint, "GET")
|
||||
//restricted tokens are valid, but get 403 on /auth/account
|
||||
//invalid tokens get 401
|
||||
if resp.StatusCode == 403 {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
// withDomainIndex checks if the domain index is initialized. If not, it's fetched from the deSEC API.
|
||||
// Next, the provided readFn function is executed to extract data from the domain index.
|
||||
func (c *desecProvider) withDomainIndex(readFn func(domainIndex map[string]uint32)) error {
|
||||
// Lock index
|
||||
c.domainIndexLock.Lock()
|
||||
defer c.domainIndexLock.Unlock()
|
||||
|
||||
// Init index if needed
|
||||
if c.domainIndex == nil {
|
||||
printer.Debugf("Domain index not yet populated, fetching now\n")
|
||||
var err error
|
||||
c.domainIndex, err = c.fetchDomainIndex()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch domain index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute handler on index
|
||||
readFn(c.domainIndex)
|
||||
return nil
|
||||
}
|
||||
func (c *desecProvider) initializeDomainIndex() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.domainIndex != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// listDomainIndex lists all the available domains in the domain index
|
||||
func (c *desecProvider) listDomainIndex() (domains []string, err error) {
|
||||
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
|
||||
domains = make([]string, 0, len(domainIndex))
|
||||
for domain := range domainIndex {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// searchDomainIndex performs a lookup to the domain index for the TTL of the domain
|
||||
func (c *desecProvider) searchDomainIndex(domain string) (ttl uint32, found bool, err error) {
|
||||
err = c.withDomainIndex(func(domainIndex map[string]uint32) {
|
||||
ttl, found = domainIndex[domain]
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *desecProvider) fetchDomainIndex() (map[string]uint32, error) {
|
||||
endpoint := "/domains/"
|
||||
var domainIndex map[string]uint32
|
||||
var bodyString, resp, err = c.get(endpoint, "GET")
|
||||
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
|
||||
//pagination is required
|
||||
links := c.convertLinks(resp.Header.Get("Link"))
|
||||
links := convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["first"]
|
||||
printer.Debugf("initial endpoint %s\n", endpoint)
|
||||
for endpoint != "" {
|
||||
bodyString, resp, err = c.get(endpoint, "GET")
|
||||
if err != nil {
|
||||
if resp.StatusCode == 404 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
return nil, fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
err = c.buildIndexFromResponse(bodyString)
|
||||
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
return nil, fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
links = c.convertLinks(resp.Header.Get("Link"))
|
||||
links = convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["next"]
|
||||
printer.Debugf("next endpoint %s\n", endpoint)
|
||||
}
|
||||
printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex))
|
||||
return nil //domainIndex was build using pagination without errors
|
||||
printer.Debugf("Domain Index fetched with pagination (%d domains)\n", len(domainIndex))
|
||||
return domainIndex, nil //domainIndex was build using pagination without errors
|
||||
}
|
||||
|
||||
//no pagination required
|
||||
if err != nil && resp.StatusCode != 400 {
|
||||
if resp.StatusCode == 404 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
return nil, fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
err = c.buildIndexFromResponse(bodyString)
|
||||
if err == nil {
|
||||
printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex))
|
||||
domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return err
|
||||
printer.Debugf("Domain Index fetched without pagination (%d domains)\n", len(domainIndex))
|
||||
return domainIndex, nil
|
||||
}
|
||||
|
||||
// buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex
|
||||
func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error {
|
||||
if c.domainIndex == nil {
|
||||
c.domainIndex = map[string]uint32{}
|
||||
}
|
||||
func appendDomainIndexFromResponse(domainIndex map[string]uint32, bodyString []byte) (map[string]uint32, error) {
|
||||
var dr []domainObject
|
||||
err := json.Unmarshal(bodyString, &dr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if domainIndex == nil {
|
||||
domainIndex = make(map[string]uint32, len(dr))
|
||||
}
|
||||
for _, domain := range dr {
|
||||
//deSEC allows different minimum ttls per domain
|
||||
//we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
|
||||
c.domainIndex[domain.Name] = domain.MinimumTTL
|
||||
domainIndex[domain.Name] = domain.MinimumTTL
|
||||
}
|
||||
return nil
|
||||
return domainIndex, nil
|
||||
}
|
||||
|
||||
// Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/main/fetch_zone.py#L13)
|
||||
func (c *desecProvider) convertLinks(links string) map[string]string {
|
||||
func convertLinks(links string) map[string]string {
|
||||
mapping := make(map[string]string)
|
||||
printer.Debugf("Header: %s\n", links)
|
||||
for _, link := range strings.Split(links, ", ") {
|
||||
|
|
@ -173,7 +187,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
|
|||
var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
|
||||
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
|
||||
//pagination required
|
||||
links := c.convertLinks(resp.Header.Get("Link"))
|
||||
links := convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["first"]
|
||||
printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain))
|
||||
for endpoint != "" {
|
||||
|
|
@ -189,7 +203,7 @@ func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
|
|||
return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err)
|
||||
}
|
||||
rrsNew = append(rrsNew, tmp...)
|
||||
links = c.convertLinks(resp.Header.Get("Link"))
|
||||
links = convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["next"]
|
||||
printer.Debugf("getRecords: next endpoint %s\n", endpoint)
|
||||
}
|
||||
|
|
@ -282,7 +296,7 @@ retry:
|
|||
client := &http.Client{}
|
||||
req, _ := http.NewRequest(method, endpoint, nil)
|
||||
q := req.URL.Query()
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))
|
||||
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
|
|
@ -304,12 +318,12 @@ retry:
|
|||
if wait > 180 {
|
||||
return []byte{}, resp, fmt.Errorf("rate limiting exceeded")
|
||||
}
|
||||
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
|
||||
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
|
||||
time.Sleep(time.Duration(wait+1) * time.Second)
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
|
||||
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
goto retry
|
||||
}
|
||||
|
|
@ -346,7 +360,7 @@ retry:
|
|||
}
|
||||
q := req.URL.Query()
|
||||
if endpoint != "/auth/login/" {
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.token))
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
|
|
@ -371,12 +385,12 @@ retry:
|
|||
if wait > 180 {
|
||||
return []byte{}, fmt.Errorf("rate limiting exceeded")
|
||||
}
|
||||
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
|
||||
printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
|
||||
time.Sleep(time.Duration(wait+1) * time.Second)
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
printer.Warnf("Rate limiting.. waiting for 500 milliseconds")
|
||||
printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
goto retry
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue