dnscontrol/providers/adguardhome/api.go
2025-07-09 12:06:34 -04:00

176 lines
4.3 KiB
Go

package adguardhome
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/StackExchange/dnscontrol/v4/pkg/printer"
)
type (
	// adguardHomeProvider holds what is needed to talk to one AdGuard Home
	// instance: the credentials sent as HTTP Basic auth and the base URL to
	// which API endpoint paths are appended.
	adguardHomeProvider struct {
		username, password string
		host               string
	}

	// requestParams is a free-form JSON object used as a request body.
	requestParams map[string]any

	// errorResponse is the shape of an API error payload carrying a
	// human-readable "message" field.
	errorResponse struct {
		Message string `json:"message"`
	}

	// rewriteEntry is a single rewrite rule as exchanged with the
	// /control/rewrite/* endpoints: a domain and the answer returned for it.
	rewriteEntry struct {
		Domain string `json:"domain"`
		Answer string `json:"answer"`
	}
)
// write sends params, JSON-encoded, as the body of an authenticated request
// to endpoint (appended to c.host) using the given HTTP method, and returns
// the raw response body. Retry and error semantics are those of doRequest.
func (c *adguardHomeProvider) write(method, endpoint string, params requestParams) ([]byte, error) {
	reqBodyJSON, err := json.Marshal(params)
	if err != nil {
		return []byte{}, err
	}
	return c.doRequest(method, endpoint, reqBodyJSON)
}

// get sends an authenticated GET request to endpoint (appended to c.host)
// and returns the raw response body. Retry and error semantics are those of
// doRequest.
func (c *adguardHomeProvider) get(endpoint string) ([]byte, error) {
	return c.doRequest(http.MethodGet, endpoint, nil)
}

// doRequest performs one logical API call against c.host+endpoint with HTTP
// Basic authentication.
//
// When body is non-nil it is sent as the request body with
// "Content-Type: application/json". A nil body sends no body and no
// Content-Type header.
//
// Responses with status 429 (Too Many Requests) or 503 (Service Unavailable)
// are retried. The wait before each retry grows by 10 seconds per attempt.
// After 5 rate-limited attempts, "rate limit exceeded" is returned together
// with the last response body.
//
// Otherwise:
//   - A body that decodes as a JSON object with a non-empty "message" field
//     yields an "AdguardHome API error" together with the body.
//   - Any other non-200 status yields an error containing the status code and
//     the body text.
//   - A 200 response returns the body and a nil error.
func (c *adguardHomeProvider) doRequest(method, endpoint string, body []byte) ([]byte, error) {
	const (
		maxAttempts    = 5
		requestTimeout = 60 * time.Second
	)

	authHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(c.username+":"+c.password))
	client := &http.Client{Timeout: requestTimeout}

	for attempt := 1; ; attempt++ {
		// Build a fresh request on every attempt. client.Do consumes the
		// request body, so re-sending the same *http.Request after a
		// rate-limit response would transmit an empty body.
		var bodyReader io.Reader
		if body != nil {
			bodyReader = bytes.NewReader(body)
		}
		req, err := http.NewRequest(method, c.host+endpoint, bodyReader)
		if err != nil {
			return []byte{}, err
		}
		req.Header.Add("Authorization", authHeader)
		if body != nil {
			req.Header.Add("Content-Type", "application/json")
		}

		resp, err := client.Do(req)
		if err != nil {
			return []byte{}, err
		}
		respBody, err := io.ReadAll(resp.Body)
		// Close explicitly (not via defer) so each retry iteration releases
		// its connection immediately.
		resp.Body.Close()
		if err != nil {
			return []byte{}, fmt.Errorf("reading AdguardHome API response: %w", err)
		}

		if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
			if attempt == maxAttempts {
				return respBody, errors.New("rate limit exceeded")
			}
			waitSeconds := attempt * 10
			printer.Warnf("rate limiting.. waiting for %d second(s)\n", waitSeconds)
			time.Sleep(time.Second * time.Duration(waitSeconds))
			continue
		}

		// Any JSON object (and the literal `null`) decodes into errorResponse
		// without error, so a successful decode alone does not indicate a
		// failure. Only a non-empty message does.
		var errResp errorResponse
		if json.Unmarshal(respBody, &errResp) == nil && errResp.Message != "" {
			return respBody, fmt.Errorf("AdguardHome API error: %s URL:%s%s ", errResp.Message, req.Host, req.URL.RequestURI())
		}
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("AdguardHome API error: HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
		}
		return respBody, nil
	}
}
// createRecord adds the rewrite entry r by POSTing its domain and answer to
// /control/rewrite/add. A failure from the API call is returned wrapped with
// context; success returns nil.
func (c *adguardHomeProvider) createRecord(r rewriteEntry) error {
	payload := requestParams{"domain": r.Domain, "answer": r.Answer}
	_, err := c.write(http.MethodPost, "/control/rewrite/add", payload)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create record (adguard home): %w", err)
}
// deleteRecord removes the rewrite entry r by POSTing its domain and answer
// to /control/rewrite/delete. A failure from the API call is returned wrapped
// with context; success returns nil.
func (c *adguardHomeProvider) deleteRecord(r rewriteEntry) error {
	payload := requestParams{"domain": r.Domain, "answer": r.Answer}
	_, err := c.write(http.MethodPost, "/control/rewrite/delete", payload)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to delete record (adguard home): %w", err)
}
// modifyRecord replaces the rewrite entry oldRe with newRe via a PUT to
// /control/rewrite/update. The body carries the existing entry under
// "target" and its replacement under "update". A failure from the API call
// is returned wrapped with context; success returns nil.
func (c *adguardHomeProvider) modifyRecord(oldRe, newRe rewriteEntry) error {
	payload := requestParams{"target": oldRe, "update": newRe}
	_, err := c.write(http.MethodPut, "/control/rewrite/update", payload)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to update record (adguard home): %w", err)
}
// getRecords fetches every rewrite entry from /control/rewrite/list and keeps
// only those belonging to domain. An entry belongs to domain if its Domain
// equals domain exactly or ends with "."+domain. The comparison is
// case-sensitive.
//
// The result is never nil on success (it may be empty).
func (c *adguardHomeProvider) getRecords(domain string) ([]rewriteEntry, error) {
	raw, err := c.get("/control/rewrite/list")
	if err != nil {
		return nil, fmt.Errorf("failed to fetch records from adguardhome: %w", err)
	}

	var all []rewriteEntry
	if err := json.Unmarshal(raw, &all); err != nil {
		return nil, fmt.Errorf("failed to parse records list from adguardhome: %w", err)
	}

	subdomainSuffix := "." + domain
	matching := make([]rewriteEntry, 0, len(all))
	for _, entry := range all {
		if entry.Domain == domain || strings.HasSuffix(entry.Domain, subdomainSuffix) {
			matching = append(matching, entry)
		}
	}
	return matching, nil
}