mirror of
https://github.com/StackExchange/dnscontrol.git
synced 2025-02-25 08:02:58 +08:00
deSEC implement pagination (#1208)
* deSEC: Implement pagination for domain list #1177 * deSEC: add debug logging for pagination * deSEC: simplify get/post methods by allowing url / api endpoints as target * deSEC: implement pagination for getRecords function * deSEC: fix linter warnings * deSEC: replace domainIndexInitalized variable with checking if the domainIndex == nil * deSEC: add mutex for domainIndex Co-authored-by: Tom Limoncelli <tlimoncelli@stackoverflow.com>
This commit is contained in:
parent
997995eb4b
commit
2832746a47
2 changed files with 160 additions and 37 deletions
|
@ -24,13 +24,16 @@ Info required in `creds.json`:
|
|||
func NewDeSec(m map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
|
||||
c := &desecProvider{}
|
||||
c.creds.token = m["auth-token"]
|
||||
c.domainIndex = map[string]uint32{}
|
||||
if c.creds.token == "" {
|
||||
return nil, fmt.Errorf("missing deSEC auth-token")
|
||||
}
|
||||
if err := c.authenticate(); err != nil {
|
||||
return nil, fmt.Errorf("authentication failed")
|
||||
}
|
||||
//DomainIndex is used for corrections (minttl) and domain creation
|
||||
if err := c.initializeDomainIndex(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
@ -81,11 +84,13 @@ func (c *desecProvider) GetDomainCorrections(dc *models.DomainConfig) ([]*models
|
|||
models.PostProcessRecords(existing)
|
||||
clean := PrepFoundRecords(existing)
|
||||
var minTTL uint32
|
||||
c.mutex.Lock()
|
||||
if ttl, ok := c.domainIndex[dc.Name]; !ok {
|
||||
minTTL = 3600
|
||||
} else {
|
||||
minTTL = ttl
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
PrepDesiredRecords(dc, minTTL)
|
||||
return c.GenerateDomainCorrections(dc, clean)
|
||||
}
|
||||
|
@ -108,10 +113,9 @@ func (c *desecProvider) GetZoneRecords(domain string) (models.Records, error) {
|
|||
|
||||
// EnsureDomainExists returns an error if domain doesn't exist.
|
||||
func (c *desecProvider) EnsureDomainExists(domain string) error {
|
||||
if err := c.fetchDomain(domain); err != nil {
|
||||
return err
|
||||
}
|
||||
// domain already exists
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if _, ok := c.domainIndex[domain]; ok {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,10 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/StackExchange/dnscontrol/v3/pkg/printer"
|
||||
|
@ -16,14 +19,14 @@ const apiBase = "https://desec.io/api/v1"
|
|||
|
||||
// Api layer for desec
|
||||
type desecProvider struct {
|
||||
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
|
||||
nameserversNames []string
|
||||
creds struct {
|
||||
domainIndex map[string]uint32 //stores the minimum ttl of each domain. (key = domain and value = ttl)
|
||||
creds struct {
|
||||
tokenid string
|
||||
token string
|
||||
user string
|
||||
password string
|
||||
}
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type domainObject struct {
|
||||
|
@ -71,37 +74,141 @@ func (c *desecProvider) authenticate() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
func (c *desecProvider) initializeDomainIndex() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.domainIndex != nil {
|
||||
return nil
|
||||
}
|
||||
endpoint := "/domains/"
|
||||
var bodyString, resp, err = c.get(endpoint, "GET")
|
||||
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
|
||||
//pagination is required
|
||||
links := c.convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["first"]
|
||||
printer.Debugf("initial endpoint %s\n", endpoint)
|
||||
for endpoint != "" {
|
||||
bodyString, resp, err = c.get(endpoint, "GET")
|
||||
if err != nil {
|
||||
if resp.StatusCode == 404 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
err = c.buildIndexFromResponse(bodyString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
links = c.convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["next"]
|
||||
printer.Debugf("next endpoint %s\n", endpoint)
|
||||
}
|
||||
printer.Debugf("Domain Index initilized with pagination (%d domains)\n", len(c.domainIndex))
|
||||
return nil //domainIndex was build using pagination without errors
|
||||
}
|
||||
|
||||
func (c *desecProvider) fetchDomain(domain string) error {
|
||||
endpoint := fmt.Sprintf("/domains/%s", domain)
|
||||
var dr domainObject
|
||||
var bodyString, statuscode, err = c.get(endpoint, "GET")
|
||||
if err != nil {
|
||||
if statuscode == 404 {
|
||||
//no pagination required
|
||||
if err != nil && resp.StatusCode != 400 {
|
||||
if resp.StatusCode == 404 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Failed fetching domain: %s", err)
|
||||
return fmt.Errorf("failed fetching domains: %s", err)
|
||||
}
|
||||
err = json.Unmarshal(bodyString, &dr)
|
||||
err = c.buildIndexFromResponse(bodyString)
|
||||
if err == nil {
|
||||
printer.Debugf("Domain Index initilized without pagination (%d domains)\n", len(c.domainIndex))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//buildIndexFromResponse takes the bodyString from initializeDomainIndex and builds the domainIndex
|
||||
func (c *desecProvider) buildIndexFromResponse(bodyString []byte) error {
|
||||
if c.domainIndex == nil {
|
||||
c.domainIndex = map[string]uint32{}
|
||||
}
|
||||
var dr []domainObject
|
||||
err := json.Unmarshal(bodyString, &dr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//deSEC allows different minimum ttls per domain
|
||||
//we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
|
||||
c.domainIndex[dr.Name] = dr.MinimumTTL
|
||||
for _, domain := range dr {
|
||||
//deSEC allows different minimum ttls per domain
|
||||
//we store the actual minimum ttl to use it in desecProvider.go GetDomainCorrections() to enforce the minimum ttl and avoid api errors.
|
||||
c.domainIndex[domain.Name] = domain.MinimumTTL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//Parses the Link Header into a map (https://github.com/desec-io/desec-tools/blob/master/fetch_zone.py#L13)
|
||||
func (c *desecProvider) convertLinks(links string) map[string]string {
|
||||
mapping := make(map[string]string)
|
||||
printer.Debugf("Header: %s\n", links)
|
||||
for _, link := range strings.Split(links, ", ") {
|
||||
tmpurl := strings.Split(link, "; ")
|
||||
if len(tmpurl) != 2 {
|
||||
fmt.Printf("unexpected link header %s", link)
|
||||
continue
|
||||
}
|
||||
r := regexp.MustCompile(`rel="(.*)"`)
|
||||
matches := r.FindStringSubmatch(tmpurl[1])
|
||||
if len(matches) != 2 {
|
||||
fmt.Printf("unexpected label %s", tmpurl[1])
|
||||
continue
|
||||
}
|
||||
// mapping["$label"] = "$URL"
|
||||
//URL = https://desec.io/api/v1/domains/{domain}/rrsets/?cursor=:next_cursor
|
||||
mapping[matches[1]] = strings.TrimSuffix(strings.TrimPrefix(tmpurl[0], "<"), ">")
|
||||
}
|
||||
return mapping
|
||||
}
|
||||
|
||||
func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
|
||||
endpoint := "/domains/%s/rrsets/"
|
||||
var rrsNew []resourceRecord
|
||||
var bodyString, resp, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
|
||||
if resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
|
||||
//pagination required
|
||||
links := c.convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["first"]
|
||||
printer.Debugf("getRecords: initial endpoint %s\n", fmt.Sprintf(endpoint, domain))
|
||||
for endpoint != "" {
|
||||
bodyString, resp, err = c.get(endpoint, "GET")
|
||||
if err != nil {
|
||||
if resp.StatusCode == 404 {
|
||||
return rrsNew, nil
|
||||
}
|
||||
return rrsNew, fmt.Errorf("getRecords: failed fetching rrsets: %s", err)
|
||||
}
|
||||
tmp, err := generateRRSETfromResponse(bodyString)
|
||||
if err != nil {
|
||||
return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err)
|
||||
}
|
||||
rrsNew = append(rrsNew, tmp...)
|
||||
links = c.convertLinks(resp.Header.Get("Link"))
|
||||
endpoint = links["next"]
|
||||
printer.Debugf("getRecords: next endpoint %s\n", endpoint)
|
||||
}
|
||||
printer.Debugf("Build rrset using pagination (%d rrs)\n", len(rrsNew))
|
||||
return rrsNew, nil //domainIndex was build using pagination without errors
|
||||
}
|
||||
//no pagination
|
||||
if err != nil {
|
||||
return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %s", domain, err)
|
||||
}
|
||||
tmp, err := generateRRSETfromResponse(bodyString)
|
||||
if err != nil {
|
||||
return rrsNew, err
|
||||
}
|
||||
rrsNew = append(rrsNew, tmp...)
|
||||
printer.Debugf("Build rrset without pagination (%d rrs)\n", len(rrsNew))
|
||||
return rrsNew, nil
|
||||
}
|
||||
|
||||
//generateRRSETfromResponse takes the response rrset api calls and returns []resourceRecord
|
||||
func generateRRSETfromResponse(bodyString []byte) ([]resourceRecord, error) {
|
||||
var rrs []rrResponse
|
||||
var rrsNew []resourceRecord
|
||||
var bodyString, _, err = c.get(fmt.Sprintf(endpoint, domain), "GET")
|
||||
if err != nil {
|
||||
return rrsNew, fmt.Errorf("Failed fetching records for domain %s (deSEC): %s", domain, err)
|
||||
}
|
||||
err = json.Unmarshal(bodyString, &rrs)
|
||||
err := json.Unmarshal(bodyString, &rrs)
|
||||
if err != nil {
|
||||
return rrsNew, err
|
||||
}
|
||||
|
@ -126,7 +233,7 @@ func (c *desecProvider) createDomain(domain string) error {
|
|||
var resp []byte
|
||||
var err error
|
||||
if resp, err = c.post(endpoint, "POST", byt); err != nil {
|
||||
return fmt.Errorf("Failed domain create (deSEC): %v", err)
|
||||
return fmt.Errorf("failed domain create (deSEC): %v", err)
|
||||
}
|
||||
dm := domainObject{}
|
||||
err = json.Unmarshal(resp, &dm)
|
||||
|
@ -143,7 +250,7 @@ func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error {
|
|||
endpoint := fmt.Sprintf("/domains/%s/rrsets/", domain)
|
||||
byt, _ := json.Marshal(rr)
|
||||
if _, err := c.post(endpoint, "PUT", byt); err != nil {
|
||||
return fmt.Errorf("Failed create RRset (deSEC): %v", err)
|
||||
return fmt.Errorf("failed create RRset (deSEC): %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -151,16 +258,22 @@ func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error {
|
|||
func (c *desecProvider) deleteRR(domain, shortname, t string) error {
|
||||
endpoint := fmt.Sprintf("/domains/%s/rrsets/%s/%s/", domain, shortname, t)
|
||||
if _, _, err := c.get(endpoint, "DELETE"); err != nil {
|
||||
return fmt.Errorf("Failed delete RRset (deSEC): %v", err)
|
||||
return fmt.Errorf("failed delete RRset (deSEC): %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *desecProvider) get(endpoint, method string) ([]byte, int, error) {
|
||||
func (c *desecProvider) get(target, method string) ([]byte, *http.Response, error) {
|
||||
retrycnt := 0
|
||||
var endpoint string
|
||||
if strings.Contains(target, "http") {
|
||||
endpoint = target
|
||||
} else {
|
||||
endpoint = apiBase + target
|
||||
}
|
||||
retry:
|
||||
client := &http.Client{}
|
||||
req, _ := http.NewRequest(method, apiBase+endpoint, nil)
|
||||
req, _ := http.NewRequest(method, endpoint, nil)
|
||||
q := req.URL.Query()
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", c.creds.token))
|
||||
|
||||
|
@ -168,7 +281,7 @@ retry:
|
|||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return []byte{}, 0, err
|
||||
return []byte{}, resp, err
|
||||
}
|
||||
|
||||
bodyString, _ := ioutil.ReadAll(resp.Body)
|
||||
|
@ -182,7 +295,7 @@ retry:
|
|||
wait, err := strconv.ParseInt(waitfor, 10, 64)
|
||||
if err == nil {
|
||||
if wait > 180 {
|
||||
return []byte{}, 0, fmt.Errorf("rate limiting exceeded")
|
||||
return []byte{}, resp, fmt.Errorf("rate limiting exceeded")
|
||||
}
|
||||
printer.Warnf("Rate limiting.. waiting for %s seconds", waitfor)
|
||||
time.Sleep(time.Duration(wait+1) * time.Second)
|
||||
|
@ -197,24 +310,30 @@ retry:
|
|||
var nfieldErrors []nonFieldError
|
||||
err = json.Unmarshal(bodyString, &errResp)
|
||||
if err == nil {
|
||||
return bodyString, resp.StatusCode, fmt.Errorf("%s", errResp.Detail)
|
||||
return bodyString, resp, fmt.Errorf("%s", errResp.Detail)
|
||||
}
|
||||
err = json.Unmarshal(bodyString, &nfieldErrors)
|
||||
if err == nil && len(nfieldErrors) > 0 {
|
||||
if len(nfieldErrors[0].Errors) > 0 {
|
||||
return bodyString, resp.StatusCode, fmt.Errorf("%s", nfieldErrors[0].Errors[0])
|
||||
return bodyString, resp, fmt.Errorf("%s", nfieldErrors[0].Errors[0])
|
||||
}
|
||||
}
|
||||
return bodyString, resp.StatusCode, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
|
||||
return bodyString, resp, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
|
||||
}
|
||||
return bodyString, resp.StatusCode, nil
|
||||
return bodyString, resp, nil
|
||||
}
|
||||
|
||||
func (c *desecProvider) post(endpoint, method string, payload []byte) ([]byte, error) {
|
||||
func (c *desecProvider) post(target, method string, payload []byte) ([]byte, error) {
|
||||
retrycnt := 0
|
||||
var endpoint string
|
||||
if strings.Contains(target, "http") {
|
||||
endpoint = target
|
||||
} else {
|
||||
endpoint = apiBase + target
|
||||
}
|
||||
retry:
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(method, apiBase+endpoint, bytes.NewReader(payload))
|
||||
req, err := http.NewRequest(method, endpoint, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue