dnscontrol/main.go
rbelnap 8cff3128ea add flag to delay between domains (#101)
* New flag: -delay=x where x is ms to delay between domains. This feature will probably be removed some day when a better rate-limiting solution can be engineered.
2017-08-05 07:49:21 -07:00

349 lines
9.4 KiB
Go

package main
import (
"bufio"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"strconv"
"strings"
"time"
"github.com/StackExchange/dnscontrol/models"
"github.com/StackExchange/dnscontrol/pkg/js"
"github.com/StackExchange/dnscontrol/pkg/nameservers"
"github.com/StackExchange/dnscontrol/pkg/normalize"
"github.com/StackExchange/dnscontrol/providers"
_ "github.com/StackExchange/dnscontrol/providers/_all"
"github.com/StackExchange/dnscontrol/providers/config"
)
//go:generate go run build/generate/generate.go
var jsFile = flag.String("js", "dnsconfig.js", "Javascript file containing dns config")
var credsFile = flag.String("creds", "creds.json", "Provider credentials JSON file")
var jsonFile = flag.String("json", "", "File containing intermediate JSON")
var jsonOutputPre = flag.String("debugrawjson", "", "Write JSON intermediate to this file pre-normalization.")
var jsonOutputPost = flag.String("debugjson", "", "During preview, write JSON intermediate to this file instead of stdout.")
var devMode = flag.Bool("dev", false, "Use helpers.js from disk instead of embedded")
var flagProviders = flag.String("providers", "", "Providers to enable (comma seperated list); default is all-but-bind. Specify 'all' for all (including bind)")
var domains = flag.String("domains", "", "Comma seperated list of domain names to include")
var interactive = flag.Bool("i", false, "Confirm or Exclude each correction before they run")
var delay = flag.Int64("d", 0, "delay between domains to avoid rate limits (in ms)")
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
flag.Parse()
command := flag.Arg(0)
if command == "version" {
fmt.Println(versionString())
return
}
var dnsConfig *models.DNSConfig
if *jsonFile != "" {
text, err := ioutil.ReadFile(*jsonFile)
if err != nil {
log.Fatalf("Error reading %v: %v\n", *jsonFile, err)
}
err = json.Unmarshal(text, &dnsConfig)
if err != nil {
log.Fatalf("Error parsing JSON in (%v): %v", *jsonFile, err)
}
} else if *jsFile != "" {
text, err := ioutil.ReadFile(*jsFile)
if err != nil {
log.Fatalf("Error reading %v: %v\n", *jsFile, err)
}
dnsConfig, err = js.ExecuteJavascript(string(text), *devMode)
if err != nil {
log.Fatalf("Error executing javasscript in (%v): %v", *jsFile, err)
}
}
if dnsConfig == nil {
log.Fatal("No config specified.")
}
if flag.NArg() != 1 {
fmt.Println(`Usage: dnscontrol [options] cmd
cmd:
preview: Show changed that would happen.
push: Make changes for real.
version: Print program version string.
print: Print compiled data.
create-domains: Pre-create domains in R53
`)
flag.PrintDefaults()
return
}
if *jsonOutputPre != "" {
dat, _ := json.MarshalIndent(dnsConfig, "", " ")
err := ioutil.WriteFile(*jsonOutputPre, dat, 0644)
if err != nil {
panic(err)
}
}
errs := normalize.NormalizeAndValidateConfig(dnsConfig)
if len(errs) > 0 {
fmt.Printf("%d Validation errors:\n", len(errs))
fatal := false
for _, err := range errs {
if _, ok := err.(normalize.Warning); ok {
fmt.Printf("WARNING: %s\n", err)
} else {
fatal = true
fmt.Printf("ERROR: %s\n", err)
}
}
if fatal {
log.Fatal("Exiting due to validation errors")
}
}
if command == "print" {
dat, _ := json.MarshalIndent(dnsConfig, "", " ")
if *jsonOutputPost == "" {
fmt.Println("While running JS:", string(dat))
} else {
err := ioutil.WriteFile(*jsonOutputPost, dat, 0644)
if err != nil {
panic(err)
}
}
return
}
providerConfigs, err := config.LoadProviderConfigs(*credsFile)
if err != nil {
log.Fatalf(err.Error())
}
nonDefaultProviders := []string{}
for name, vals := range providerConfigs {
// add "_exclude_from_defaults":"true" to a domain to exclude it from being run unless
// -providers=all or -providers=name
if vals["_exclude_from_defaults"] == "true" {
nonDefaultProviders = append(nonDefaultProviders, name)
}
}
registrars, err := providers.CreateRegistrars(dnsConfig, providerConfigs)
if err != nil {
log.Fatalf("Error creating registrars: %v\n", err)
}
dsps, err := providers.CreateDsps(dnsConfig, providerConfigs)
if err != nil {
log.Fatalf("Error creating dsps: %v\n", err)
}
fmt.Printf("Initialized %d registrars and %d dns service providers.\n", len(registrars), len(dsps))
anyErrors, totalCorrections := false, 0
switch command {
case "create-domains":
for _, domain := range dnsConfig.Domains {
fmt.Println("*** ", domain.Name)
for prov := range domain.DNSProviders {
dsp, ok := dsps[prov]
if !ok {
log.Fatalf("DSP %s not declared.", prov)
}
if creator, ok := dsp.(providers.DomainCreator); ok {
fmt.Println(" -", prov)
err := creator.EnsureDomainExists(domain.Name)
if err != nil {
fmt.Printf("Error creating domain: %s\n", err)
}
}
}
}
case "preview", "push":
DomainLoop:
for _, domain := range dnsConfig.Domains {
if !shouldRunDomain(domain.Name) {
continue
}
fmt.Printf("******************** Domain: %s\n", domain.Name)
nsList, err := nameservers.DetermineNameservers(domain, 0, dsps)
if err != nil {
log.Fatal(err)
}
domain.Nameservers = nsList
nameservers.AddNSRecords(domain)
for prov := range domain.DNSProviders {
dc, err := domain.Copy()
if err != nil {
log.Fatal(err)
}
shouldrun := shouldRunProvider(prov, dc, nonDefaultProviders)
statusLbl := ""
if !shouldrun {
statusLbl = "(skipping)"
}
fmt.Printf("----- DNS Provider: %s... %s", prov, statusLbl)
if !shouldrun {
fmt.Println()
continue
}
dsp, ok := dsps[prov]
if !ok {
log.Fatalf("DSP %s not declared.", prov)
}
corrections, err := dsp.GetDomainCorrections(dc)
if err != nil {
fmt.Println("ERROR")
anyErrors = true
fmt.Printf("Error getting corrections: %s\n", err)
continue DomainLoop
}
totalCorrections += len(corrections)
plural := "s"
if len(corrections) == 1 {
plural = ""
}
fmt.Printf("%d correction%s\n", len(corrections), plural)
anyErrors = printOrRunCorrections(corrections, command) || anyErrors
}
if run := shouldRunProvider(domain.Registrar, domain, nonDefaultProviders); !run {
continue
}
fmt.Printf("----- Registrar: %s\n", domain.Registrar)
reg, ok := registrars[domain.Registrar]
if !ok {
log.Fatalf("Registrar %s not declared.", reg)
}
if len(domain.Nameservers) == 0 && domain.Metadata["no_ns"] != "true" {
fmt.Printf("No nameservers declared; skipping registrar. Add {no_ns:'true'} to force.\n")
continue
}
dc, err := domain.Copy()
if err != nil {
log.Fatal(err)
}
corrections, err := reg.GetRegistrarCorrections(dc)
if err != nil {
fmt.Printf("Error getting corrections: %s\n", err)
anyErrors = true
continue
}
totalCorrections += len(corrections)
anyErrors = printOrRunCorrections(corrections, command) || anyErrors
time.Sleep(time.Duration(*delay) * time.Millisecond)
}
default:
log.Fatalf("Unknown command %s", command)
}
if os.Getenv("TEAMCITY_VERSION") != "" {
fmt.Fprintf(os.Stderr, "##teamcity[buildStatus status='SUCCESS' text='%d corrections']", totalCorrections)
}
fmt.Printf("Done. %d corrections.\n", totalCorrections)
if anyErrors {
os.Exit(1)
}
}
var reader = bufio.NewReader(os.Stdin)
func printOrRunCorrections(corrections []*models.Correction, command string) (anyErrors bool) {
anyErrors = false
if len(corrections) == 0 {
return anyErrors
}
for i, correction := range corrections {
fmt.Printf("#%d: %s\n", i+1, correction.Msg)
if command == "push" {
if *interactive {
fmt.Print("Run? (Y/n): ")
txt, err := reader.ReadString('\n')
run := true
if err != nil {
run = false
}
txt = strings.ToLower(strings.TrimSpace(txt))
if txt != "y" {
run = false
}
if !run {
fmt.Println("Skipping")
continue
}
}
err := correction.F()
if err != nil {
fmt.Println("FAILURE!", err)
anyErrors = true
} else {
fmt.Println("SUCCESS!")
}
}
}
return anyErrors
}
func shouldRunProvider(p string, dc *models.DomainConfig, nonDefaultProviders []string) bool {
if *flagProviders == "all" {
return true
}
if *flagProviders == "" {
for _, pr := range nonDefaultProviders {
if pr == p {
return false
}
}
return true
}
for _, prov := range strings.Split(*flagProviders, ",") {
if prov == p {
return true
}
}
return false
}
func shouldRunDomain(d string) bool {
if *domains == "" {
return true
}
for _, dom := range strings.Split(*domains, ",") {
if dom == d {
return true
}
}
return false
}
// Version management. 2 Goals:
// 1. Someone who just does "go get" has at least some information.
// 2. If built with build.sh, more specific build information gets put in.
// Update the number here manually each release, so at least we have a range for go-get people.
var (
SHA = ""
Version = "0.1.0"
BuildTime = ""
)
// printVersion prints the version banner.
func versionString() string {
var version string
if SHA != "" {
version = fmt.Sprintf("%s (%s)", Version, SHA)
} else {
version = fmt.Sprintf("%s-dev", Version) //no SHA. '0.x.y-dev' indeicates it is run form source without build script.
}
if BuildTime != "" {
i, err := strconv.ParseInt(BuildTime, 10, 64)
if err == nil {
tm := time.Unix(i, 0)
version += fmt.Sprintf(" built %s", tm.Format(time.RFC822))
}
}
return fmt.Sprintf("dnscontrol %s", version)
}