dnscontrol/providers/desec/protocol.go
2024-12-12 14:07:11 -05:00

414 lines
13 KiB
Go

package desec
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/StackExchange/dnscontrol/v4/pkg/printer"
)
// apiBase is the root URL of the deSEC REST API. Relative endpoint paths such
// as "/domains/" are appended to it by get and post.
const apiBase = "https://desec.io/api/v1"
// desecProvider is the API layer for deSEC. It carries the API token and a
// lazily populated cache of the account's domains.
type desecProvider struct {
	// domainIndex maps each domain name to its minimum TTL (deSEC allows a
	// different minimum TTL per domain). nil means "not fetched yet".
	domainIndex map[string]uint32
	// domainIndexLock guards domainIndex; withDomainIndex holds it while
	// populating and reading the index.
	domainIndexLock sync.Mutex
	// token is sent as "Authorization: Token <token>" on API requests.
	token string
}
// domainObject is a domain (zone) as exchanged with the /domains/ endpoint:
// it is the payload of createDomain and the element type decoded by
// appendDomainIndexFromResponse.
// NOTE: omitempty has no effect on time.Time fields (they are structs), so
// zero Created/Published values are still included when marshaled.
type domainObject struct {
	Created time.Time   `json:"created,omitempty"`
	Keys    []dnssecKey `json:"keys,omitempty"`
	// MinimumTTL is the smallest TTL deSEC accepts for records of this domain.
	MinimumTTL uint32    `json:"minimum_ttl,omitempty"`
	Name       string    `json:"name,omitempty"`
	Published  time.Time `json:"published,omitempty"`
}
// resourceRecord is one RRset: all values of a given type at a given subname.
// It is the payload sent by upsertRR and the result type of
// generateRRSETfromResponse.
type resourceRecord struct {
	Subname string   `json:"subname"`
	Records []string `json:"records"`
	// TTL is omitted from the JSON payload when zero.
	TTL  uint32 `json:"ttl,omitempty"`
	Type string `json:"type"`
	// Target is never serialized (json:"-"); it is not filled by the code in
	// this file. NOTE(review): presumably used by the provider code elsewhere.
	Target string `json:"-"`
}
// rrResponse is an RRset as returned by the rrsets API: the resourceRecord
// fields plus server-side metadata that generateRRSETfromResponse discards.
type rrResponse struct {
	resourceRecord
	Created time.Time `json:"created"`
	Domain  string    `json:"domain"`
	Name    string    `json:"name"`
}
// dnssecKey is DNSSEC key material reported for a domain. Ds holds the DS
// record(s) that createDomain asks the user to convey to the registrar.
type dnssecKey struct {
	Dnskey  string   `json:"dnskey"`
	Ds      []string `json:"ds"`
	Flags   int      `json:"flags"`
	Keytype string   `json:"keytype"`
}
// errorResponse is an API error body of the shape {"detail": "..."}.
type errorResponse struct {
	Detail string `json:"detail"`
}
// nonFieldError is one element of an array-shaped API error body such as
// [{"non_field_errors": ["..."]}].
type nonFieldError struct {
	Errors []string `json:"non_field_errors"`
}
// withDomainIndex runs readFn against the cached domain index, fetching the
// index from the deSEC API first if it has not been populated yet.
//
// domainIndexLock is held for the whole call, including readFn, so readFn
// must not call back into withDomainIndex (sync.Mutex is not reentrant).
func (c *desecProvider) withDomainIndex(readFn func(domainIndex map[string]uint32)) error {
	c.domainIndexLock.Lock()
	defer c.domainIndexLock.Unlock()

	if c.domainIndex == nil {
		printer.Debugf("Domain index not yet populated, fetching now\n")
		fetched, fetchErr := c.fetchDomainIndex()
		if fetchErr != nil {
			return fmt.Errorf("failed to fetch domain index: %w", fetchErr)
		}
		c.domainIndex = fetched
	}

	readFn(c.domainIndex)
	return nil
}
// listDomainIndex returns the names of all domains in the domain index, in
// map iteration (i.e. unspecified) order. On error domains is nil.
func (c *desecProvider) listDomainIndex() (domains []string, err error) {
	collect := func(index map[string]uint32) {
		names := make([]string, 0, len(index))
		for name := range index {
			names = append(names, name)
		}
		domains = names
	}
	err = c.withDomainIndex(collect)
	return domains, err
}
// searchDomainIndex looks up domain in the domain index. found reports whether
// the domain exists; ttl is its minimum TTL (0 when not found).
func (c *desecProvider) searchDomainIndex(domain string) (ttl uint32, found bool, err error) {
	lookup := func(index map[string]uint32) {
		ttl, found = index[domain]
	}
	if err = c.withDomainIndex(lookup); err != nil {
		return ttl, found, err
	}
	return ttl, found, nil
}
// fetchDomainIndex downloads all domains of the account from the deSEC API and
// returns a map from domain name to that domain's minimum TTL.
//
// deSEC signals that pagination is required by answering the plain request
// with HTTP 400 plus a Link header. In that case the pages are walked starting
// at the "first" link and following "next" links until none is left.
func (c *desecProvider) fetchDomainIndex() (map[string]uint32, error) {
	bodyString, resp, err := c.get("/domains/", "GET")

	// resp is nil when the request could not be sent at all, so guard before
	// looking at the status code.
	if resp != nil && resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
		var domainIndex map[string]uint32
		endpoint := convertLinks(resp.Header.Get("Link"))["first"]
		printer.Debugf("initial endpoint %s\n", endpoint)
		for endpoint != "" {
			bodyString, resp, err = c.get(endpoint, "GET")
			if err != nil {
				return nil, fmt.Errorf("failed fetching domains: %w", err)
			}
			domainIndex, err = appendDomainIndexFromResponse(domainIndex, bodyString)
			if err != nil {
				return nil, fmt.Errorf("failed fetching domains: %w", err)
			}
			endpoint = convertLinks(resp.Header.Get("Link"))["next"]
			printer.Debugf("next endpoint %s\n", endpoint)
		}
		if domainIndex == nil {
			// No page was fetched. Return an empty, non-nil index so that
			// withDomainIndex does not treat it as unpopulated and refetch.
			domainIndex = map[string]uint32{}
		}
		printer.Debugf("Domain Index fetched with pagination (%d domains)\n", len(domainIndex))
		return domainIndex, nil
	}

	// No pagination required. Any error here (including a 400 without a Link
	// header) is a real API failure; report it instead of trying to decode the
	// error body as a domain list.
	if err != nil {
		return nil, fmt.Errorf("failed fetching domains: %w", err)
	}
	domainIndex, err := appendDomainIndexFromResponse(nil, bodyString)
	if err != nil {
		return nil, err
	}
	printer.Debugf("Domain Index fetched without pagination (%d domains)\n", len(domainIndex))
	return domainIndex, nil
}
// appendDomainIndexFromResponse decodes a JSON array of domain objects from
// bodyString and records each domain's minimum TTL in domainIndex, creating
// the map if domainIndex is nil. On a decode error it returns (nil, err).
//
// deSEC allows a different minimum TTL per domain; it is stored so the
// provider can enforce it later and avoid API errors.
func appendDomainIndexFromResponse(domainIndex map[string]uint32, bodyString []byte) (map[string]uint32, error) {
	var page []domainObject
	if err := json.Unmarshal(bodyString, &page); err != nil {
		return nil, err
	}

	index := domainIndex
	if index == nil {
		index = make(map[string]uint32, len(page))
	}
	for i := range page {
		index[page[i].Name] = page[i].MinimumTTL
	}
	return index, nil
}
// linkRelRegexp extracts the relation name from a Link header parameter such
// as `rel="next"`. It is compiled once rather than on every loop iteration.
var linkRelRegexp = regexp.MustCompile(`rel="(.*)"`)

// convertLinks parses an HTTP Link header of the form
// `<url1>; rel="first", <url2>; rel="next"` into a map from relation name to
// URL (same approach as
// https://github.com/desec-io/desec-tools/blob/main/fetch_zone.py#L13).
// Malformed entries are reported and skipped; an empty header yields an
// empty map.
func convertLinks(links string) map[string]string {
	mapping := make(map[string]string)
	printer.Debugf("Header: %s\n", links)
	if links == "" {
		return mapping
	}
	for _, link := range strings.Split(links, ", ") {
		parts := strings.Split(link, "; ")
		if len(parts) != 2 {
			printer.Printf("unexpected link header %s\n", link)
			continue
		}
		matches := linkRelRegexp.FindStringSubmatch(parts[1])
		if len(matches) != 2 {
			printer.Printf("unexpected label %s\n", parts[1])
			continue
		}
		// e.g. mapping["next"] = "https://desec.io/api/v1/domains/{domain}/rrsets/?cursor=:next_cursor"
		mapping[matches[1]] = strings.TrimSuffix(strings.TrimPrefix(parts[0], "<"), ">")
	}
	return mapping
}
// getRecords fetches all RRsets of domain from the deSEC API.
//
// If deSEC signals that pagination is required (HTTP 400 with a Link header),
// the pages are walked from the "first" link following "next" links. A 404
// while paging is treated as "no (more) records" and returns what has been
// collected so far without an error.
func (c *desecProvider) getRecords(domain string) ([]resourceRecord, error) {
	var rrsNew []resourceRecord
	bodyString, resp, err := c.get(fmt.Sprintf("/domains/%s/rrsets/", domain), "GET")

	// resp is nil when the request could not be sent at all.
	if resp != nil && resp.StatusCode == 400 && resp.Header.Get("Link") != "" {
		endpoint := convertLinks(resp.Header.Get("Link"))["first"]
		// endpoint is a full URL here; it must not be used as a format string.
		printer.Debugf("getRecords: initial endpoint %s\n", endpoint)
		for endpoint != "" {
			bodyString, resp, err = c.get(endpoint, "GET")
			if err != nil {
				if resp != nil && resp.StatusCode == 404 {
					return rrsNew, nil
				}
				return rrsNew, fmt.Errorf("getRecords: failed fetching rrsets: %w", err)
			}
			page, parseErr := generateRRSETfromResponse(bodyString)
			if parseErr != nil {
				return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %w", domain, parseErr)
			}
			rrsNew = append(rrsNew, page...)
			endpoint = convertLinks(resp.Header.Get("Link"))["next"]
			printer.Debugf("getRecords: next endpoint %s\n", endpoint)
		}
		printer.Debugf("Build rrset using pagination (%d rrs)\n", len(rrsNew))
		return rrsNew, nil // all pages fetched without errors
	}

	// No pagination required.
	if err != nil {
		return rrsNew, fmt.Errorf("failed fetching records for domain %s (deSEC): %w", domain, err)
	}
	page, err := generateRRSETfromResponse(bodyString)
	if err != nil {
		return rrsNew, err
	}
	rrsNew = append(rrsNew, page...)
	printer.Debugf("Build rrset without pagination (%d rrs)\n", len(rrsNew))
	return rrsNew, nil
}
// generateRRSETfromResponse decodes the JSON body of an rrsets API response
// and converts each returned RRset into a resourceRecord. Only the subname,
// type, TTL and record values are kept; server-side metadata (created,
// domain, name) is dropped.
//
// Each RRset stays a single resourceRecord whose Records slice holds all of
// its values. No per-record splitting happens here.
func generateRRSETfromResponse(bodyString []byte) ([]resourceRecord, error) {
	var rrs []rrResponse
	if err := json.Unmarshal(bodyString, &rrs); err != nil {
		return nil, err
	}
	rrsNew := make([]resourceRecord, 0, len(rrs))
	for i := range rrs {
		rrsNew = append(rrsNew, resourceRecord{
			Subname: rrs[i].Subname,
			Type:    rrs[i].Type,
			TTL:     rrs[i].TTL,
			Records: rrs[i].Records,
		})
	}
	return rrsNew, nil
}
// createDomain creates domain in the deSEC account and prints the DS
// record(s) from the API response, which the user must hand to the registrar
// to enable DNSSEC validation.
func (c *desecProvider) createDomain(domain string) error {
	endpoint := "/domains/"
	payload, err := json.Marshal(domainObject{Name: domain})
	if err != nil {
		return fmt.Errorf("failed domain create (deSEC): %w", err)
	}
	resp, err := c.post(endpoint, "POST", payload)
	if err != nil {
		return fmt.Errorf("failed domain create (deSEC): %w", err)
	}
	var dm domainObject
	if err := json.Unmarshal(resp, &dm); err != nil {
		return err
	}
	printer.Printf("To enable DNSSEC validation for your domain, make sure to convey the DS record(s) to your registrar:\n")
	// Print each DS record on its own line. Formatting the whole key struct
	// with %+q would render the integer flags as quoted runes.
	for _, key := range dm.Keys {
		for _, ds := range key.Ds {
			printer.Printf("%s\n", ds)
		}
	}
	return nil
}
// upsertRR creates or overwrites the RRsets in rr for domain with a single
// bulk PUT to the domain's rrsets endpoint.
func (c *desecProvider) upsertRR(rr []resourceRecord, domain string) error {
	endpoint := fmt.Sprintf("/domains/%s/rrsets/", domain)
	payload, err := json.Marshal(rr)
	if err != nil {
		return fmt.Errorf("failed create RRset (deSEC): %w", err)
	}
	if _, err := c.post(endpoint, "PUT", payload); err != nil {
		return fmt.Errorf("failed create RRset (deSEC): %w", err)
	}
	return nil
}
// Uncomment this function if it is needed again.
// It was commented out to silence `staticcheck` warnings about unused code.
//func (c *desecProvider) deleteRR(domain, shortname, t string) error {
// endpoint := fmt.Sprintf("/domains/%s/rrsets/%s/%s/", domain, shortname, t)
// if _, _, err := c.get(endpoint, "DELETE"); err != nil {
// return fmt.Errorf("failed delete RRset (deSEC): %w", err)
// }
// return nil
//}
// apiHTTPClient is shared by every deSEC API call so TCP/TLS connections are
// reused between requests. The timeout keeps a stalled connection from hanging
// the run forever. It is deliberately generous; NOTE(review): tune it if very
// large zone updates need longer.
var apiHTTPClient = &http.Client{Timeout: 5 * time.Minute}

const (
	// maxRateLimitRetries is how often a request answered with HTTP 429 is retried.
	maxRateLimitRetries = 5
	// maxRetryAfterSeconds is the longest Retry-After we are willing to wait for.
	maxRetryAfterSeconds = 180
	// apiLoginPath is the endpoint that is called without the auth token.
	apiLoginPath = "/auth/login/"
)

// resolveAPIEndpoint turns a path relative to apiBase into a full URL.
// Absolute URLs (e.g. pagination links taken from a Link header) are returned
// unchanged. A prefix check is used so that paths which merely contain "http"
// (such as /domains/myhttpsite.com/rrsets/) are still treated as relative.
func resolveAPIEndpoint(target string) string {
	if strings.HasPrefix(target, "https://") || strings.HasPrefix(target, "http://") {
		return target
	}
	return apiBase + target
}

// doAPIRequest sends method to target (optionally with a JSON payload) and
// returns the fully read response body together with the response.
//
// HTTP 429 responses are retried up to maxRateLimitRetries times (see
// rateLimitBackoff). Other non-2xx statuses are NOT turned into errors here;
// callers inspect resp.StatusCode. The response body is always closed before
// returning, while resp's status and headers remain usable. resp may be nil
// when the request could not be sent.
func (c *desecProvider) doAPIRequest(method, target string, payload []byte, sendJSON bool) ([]byte, *http.Response, error) {
	endpoint := resolveAPIEndpoint(target)
	for attempt := 0; ; attempt++ {
		// A fresh reader is needed per attempt because a retry re-sends the body.
		var reqBody io.Reader
		if sendJSON {
			reqBody = bytes.NewReader(payload)
		}
		req, err := http.NewRequest(method, endpoint, reqBody)
		if err != nil {
			return []byte{}, nil, err
		}
		if endpoint != apiBase+apiLoginPath {
			req.Header.Set("Authorization", "Token "+c.token)
		}
		if sendJSON {
			req.Header.Set("Content-Type", "application/json")
		}

		resp, err := apiHTTPClient.Do(req)
		if err != nil {
			return []byte{}, resp, err
		}
		body, err := io.ReadAll(resp.Body)
		// The body has been consumed; closing lets the transport reuse the
		// connection. A Close error carries no useful information here.
		resp.Body.Close()
		if err != nil {
			return []byte{}, resp, fmt.Errorf("reading response from %s: %w", endpoint, err)
		}

		if resp.StatusCode != http.StatusTooManyRequests || attempt >= maxRateLimitRetries {
			return body, resp, nil
		}
		wait, err := rateLimitBackoff(resp)
		if err != nil {
			return []byte{}, resp, err
		}
		time.Sleep(wait)
	}
}

// rateLimitBackoff decides how long to sleep after an HTTP 429. It uses the
// Retry-After header (seconds, plus one second of slack) when that parses as
// an integer and falls back to 500ms otherwise. A Retry-After larger than
// maxRetryAfterSeconds aborts with an error instead of waiting.
func rateLimitBackoff(resp *http.Response) (time.Duration, error) {
	if waitfor := resp.Header.Get("Retry-After"); waitfor != "" {
		if wait, err := strconv.ParseInt(waitfor, 10, 64); err == nil {
			if wait > maxRetryAfterSeconds {
				return 0, errors.New("rate limiting exceeded")
			}
			printer.Warnf("Rate limiting.. waiting for %s seconds\n", waitfor)
			return time.Duration(wait+1) * time.Second, nil
		}
	}
	printer.Warnf("Rate limiting.. waiting for 500 milliseconds\n")
	return 500 * time.Millisecond, nil
}

// apiErrorDetail returns the non-empty "detail" message of an error body of
// the shape {"detail": "..."}, or "" if there is none. An object without a
// detail field (e.g. per-field validation errors) yields "" rather than an
// empty error message.
func apiErrorDetail(body []byte) string {
	var errResp errorResponse
	if err := json.Unmarshal(body, &errResp); err != nil {
		return ""
	}
	return errResp.Detail
}

// apiNonFieldError returns the first message of an error body of the shape
// [{"non_field_errors": ["..."]}], or "" if there is none.
func apiNonFieldError(body []byte) string {
	var nfieldErrors []nonFieldError
	if err := json.Unmarshal(body, &nfieldErrors); err != nil {
		return ""
	}
	if len(nfieldErrors) == 0 || len(nfieldErrors[0].Errors) == 0 {
		return ""
	}
	return nfieldErrors[0].Errors[0]
}

// get performs a body-less request (GET, DELETE, ...) against target, which
// is either a path relative to apiBase or an absolute URL such as a pagination
// link.
//
// For non-2xx responses the body and resp are returned together with an error
// built from the API's error message, so callers can still inspect the status
// code and headers (e.g. the Link header of a 400 pagination response). resp
// may be nil when the request could not be sent.
func (c *desecProvider) get(target, method string) ([]byte, *http.Response, error) {
	bodyString, resp, err := c.doAPIRequest(method, target, nil, false)
	if err != nil {
		return bodyString, resp, err
	}
	if resp.StatusCode > 299 {
		if detail := apiErrorDetail(bodyString); detail != "" {
			return bodyString, resp, errors.New(detail)
		}
		if msg := apiNonFieldError(bodyString); msg != "" {
			return bodyString, resp, errors.New(msg)
		}
		return bodyString, resp, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
	}
	return bodyString, resp, nil
}

// post sends payload as JSON to target with the given method (POST, PUT, ...)
// and returns the response body. Non-2xx responses are returned as errors that
// carry the API's error message.
func (c *desecProvider) post(target, method string, payload []byte) ([]byte, error) {
	bodyString, resp, err := c.doAPIRequest(method, target, payload, true)
	if err != nil {
		return bodyString, err
	}
	if resp.StatusCode > 299 {
		if detail := apiErrorDetail(bodyString); detail != "" {
			// resp.Status already contains the numeric code ("400 Bad Request").
			return bodyString, fmt.Errorf("HTTP status %s details: %s", resp.Status, detail)
		}
		if msg := apiNonFieldError(bodyString); msg != "" {
			return bodyString, errors.New(msg)
		}
		return bodyString, fmt.Errorf("HTTP status %s Body: %s, the API does not provide more information", resp.Status, bodyString)
	}
	return bodyString, nil
}