diff --git a/cmd/run.go b/cmd/run.go index b1b441b..5733d3c 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -1,7 +1,6 @@ package cmd import ( - "github.com/bakito/adguardhome-sync/pkg/client" "github.com/bakito/adguardhome-sync/pkg/log" "github.com/bakito/adguardhome-sync/pkg/sync" "github.com/bakito/adguardhome-sync/pkg/types" @@ -24,21 +23,7 @@ var doCmd = &cobra.Command{ return } - origin, err := client.New(cfg.Origin) - if err != nil { - logger.Error(err) - return - } - replica, err := client.New(cfg.Replica) - if err != nil { - logger.Error(err) - return - } - err = sync.Sync(origin, replica) - if err != nil { - logger.Error(err) - return - } + sync.Sync(cfg) }, } diff --git a/pkg/client/client.go b/pkg/client/client.go index bd5a93f..b9c5852 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -65,7 +65,7 @@ type Client interface { ToggleParental(enable bool) error ToggleSafeSearch(enable bool) error - Services() (types.Services, error) + Services() (*types.Services, error) SetServices(services types.Services) error Clients() (*types.Clients, error) @@ -190,10 +190,10 @@ func (cl *client) ToggleFiltering(enabled bool, interval int) error { return err } -func (cl *client) Services() (types.Services, error) { +func (cl *client) Services() (*types.Services, error) { svcs := &types.Services{} _, err := cl.client.R().EnableTrace().SetResult(svcs).Get("/blocked_services/list") - return *svcs, err + return svcs, err } func (cl *client) SetServices(services types.Services) error { diff --git a/pkg/sync/sync.go b/pkg/sync/sync.go index 977ee6d..870c29a 100644 --- a/pkg/sync/sync.go +++ b/pkg/sync/sync.go @@ -3,73 +3,118 @@ package sync import ( "github.com/bakito/adguardhome-sync/pkg/client" "github.com/bakito/adguardhome-sync/pkg/log" + "github.com/bakito/adguardhome-sync/pkg/types" + "go.uber.org/zap" ) // Sync config from origin to replica -func Sync(origin client.Client, replica client.Client) error { - - l := log.GetLogger("sync").With("from", origin.Host(), "to", replica.Host()) - l.Info("Start sync") - - os, err := origin.Status() +func Sync(cfg *types.Config) { + l := log.GetLogger("sync") + oc, err := client.New(cfg.Origin) if err != nil { - return err + l.With("error", err, "url", cfg.Origin.URL).Error("Error creating origin client") + return } - rs, err := replica.Status() + l = l.With("from", oc.Host()) + + o := &origin{} + + o.status, err = oc.Status() if err != nil { - return err + l.With("error", err).Error("Error getting origin status") + return } - if os.Version != rs.Version { - panic("Versions do not match") - } - - err = syncRewrites(origin, replica) + o.rewrites, err = oc.RewriteList() if err != nil { - return err + l.With("error", err).Error("Error getting origin rewrites") + return } - err = syncFilters(origin, replica) + + o.services, err = oc.Services() if err != nil { - return err + l.With("error", err).Error("Error getting origin services") + return } - err = syncServices(origin, replica) + o.filters, err = oc.Filtering() if err != nil { - return err + l.With("error", err).Error("Error getting origin filters") + return + } + o.clients, err = oc.Clients() + if err != nil { + l.With("error", err).Error("Error getting origin clients") + return } - if err = syncClients(origin, replica); err != nil { - return err + replicas := cfg.UniqueReplicas() + for _, replica := range replicas { + syncTo(l, o, replica) } - - l.Info("Sync done") - return nil } -func syncServices(origin client.Client, replica client.Client) error { - os, err := origin.Services() +func syncTo(l *zap.SugaredLogger, o *origin, replica types.AdGuardInstance) { + + rc, err := client.New(replica) if err != nil { - return err + l.With("error", err, "url", replica.URL).Error("Error creating replica client") } + + rl := l.With("to", rc.Host()) + rl.Info("Start sync") + + rs, err := rc.Status() + if err != nil { + l.With("error", err).Error("Error getting replica status") + return + } + + if o.status.Version != rs.Version { + l.With("originVersion", o.status.Version, "replicaVersion", rs.Version).Warn("Versions do not match") + } + + err = syncRewrites(o.rewrites, rc) + if err != nil { + l.With("error", err).Error("Error syncing rewrites") + return + } + err = syncFilters(o.filters, rc) + if err != nil { + l.With("error", err).Error("Error syncing filters") + return + } + + err = syncServices(o.services, rc) + if err != nil { + l.With("error", err).Error("Error syncing services") + return + } + + if err = syncClients(o.clients, rc); err != nil { + l.With("error", err).Error("Error syncing clients") + return + } + + rl.Info("Sync done") +} + +func syncServices(os *types.Services, replica client.Client) error { rs, err := replica.Services() if err != nil { return err } if !os.Equals(rs) { - if err := replica.SetServices(os); err != nil { + if err := replica.SetServices(*os); err != nil { return err } } return nil } -func syncFilters(origin client.Client, replica client.Client) error { - of, err := origin.Filtering() - if err != nil { - return err - } +func syncFilters(of *types.FilteringStatus, replica client.Client) error { rf, err := replica.Filtering() if err != nil { return err @@ -118,17 +163,14 @@ func syncFilters(origin client.Client, replica client.Client) error { return nil } -func syncRewrites(origin client.Client, replica client.Client) error { - originRewrites, err := origin.RewriteList() - if err != nil { - return err - } +func syncRewrites(or *types.RewriteEntries, replica client.Client) error { + replicaRewrites, err := replica.RewriteList() if err != nil { return err } - a, r := replicaRewrites.Merge(originRewrites) + a, r := replicaRewrites.Merge(or) if err = replica.AddRewriteEntries(a...); err != nil { return err @@ -139,11 +181,7 @@ func syncRewrites(origin client.Client, replica client.Client) error { return nil } -func syncClients(origin client.Client, replica client.Client) error { - oc, err := origin.Clients() - if err != nil { - return err - } +func syncClients(oc *types.Clients, replica client.Client) error { rc, err := replica.Clients() if err != nil { return err @@ -162,3 +200,11 @@ func syncClients(origin client.Client, replica client.Client) error { } return nil } + +type origin struct { + status *types.Status + rewrites *types.RewriteEntries + services *types.Services + filters *types.FilteringStatus + clients *types.Clients +} diff --git a/pkg/types/types.go b/pkg/types/types.go index 6f8de7e..866ebfc 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -9,8 +9,26 @@ import ( ) type Config struct { - Origin AdGuardInstance `json:"origin" yaml:"origin"` - Replica AdGuardInstance `json:"replica" yaml:"replica"` + Origin AdGuardInstance `json:"origin" yaml:"origin"` + Replica *AdGuardInstance `json:"replica,omitempty" yaml:"replica,omitempty"` + Replicas []AdGuardInstance `json:"replicas,omitempty" yaml:"replicas,omitempty"` + Cron string `json:"cron,omitempty" yaml:"cron,omitempty"` +} + +func (cfg *Config) UniqueReplicas() []AdGuardInstance { + dedup := make(map[string]AdGuardInstance) + if cfg.Replica != nil { + dedup[cfg.Replica.Key()] = *cfg.Replica + } + for _, replica := range cfg.Replicas { + dedup[replica.Key()] = replica + } + + var r []AdGuardInstance + for _, replica := range dedup { + r = append(r, replica) + } + return r } type AdGuardInstance struct { @@ -21,6 +39,10 @@ type AdGuardInstance struct { InsecureSkipVerify bool `json:"insecureSkipVerify" yaml:"insecureSkipVerify"` } +func (i *AdGuardInstance) Key() string { + return fmt.Sprintf("%s%s", i.URL, i.APIPath) +} + type Status struct { DNSAddresses []string `json:"dns_addresses"` DNSPort int `json:"dns_port"` @@ -131,10 +153,10 @@ func (s Services) Sort() { sort.Strings(s) } -func (s Services) Equals(o Services) bool { +func (s *Services) Equals(o *Services) bool { s.Sort() o.Sort() - return equals(s, o) + return equals(*s, *o) } type Clients struct {