diff --git a/main.go b/main.go index 281447f..42ea820 100644 --- a/main.go +++ b/main.go @@ -26,11 +26,16 @@ func main() { panic(err) } - err = syncRewrites(err, origin, replica) + err = syncRewrites(origin, replica) if err != nil { panic(err) } - err = syncFilters(err, origin, replica) + err = syncFilters(origin, replica) + if err != nil { + panic(err) + } + + err = syncServices(origin, replica) if err != nil { panic(err) } @@ -38,7 +43,25 @@ func main() { // POST http://192.168.2.207/control/dns_config {"protection_enabled":false} } -func syncFilters(err error, origin client.Client, replica client.Client) error { +func syncServices(origin client.Client, replica client.Client) error { + os, err := origin.Services() + if err != nil { + return err + } + rs, err := replica.Services() + if err != nil { + return err + } + + if !os.Equals(rs) { + if err := replica.SetServices(os); err != nil { + return err + } + } + return nil +} + +func syncFilters(origin client.Client, replica client.Client) error { of, err := origin.Filtering() if err != nil { return err @@ -91,7 +114,7 @@ func syncFilters(err error, origin client.Client, replica client.Client) error { return nil } -func syncRewrites(err error, origin client.Client, replica client.Client) error { +func syncRewrites(origin client.Client, replica client.Client) error { originRewrites, err := origin.RewriteList() if err != nil { return err diff --git a/pkg/client/client.go b/pkg/client/client.go index b2f3008..7703db1 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -47,6 +47,9 @@ type Client interface { ToggleSaveBrowsing(enable bool) error ToggleParental(enable bool) error ToggleSafeSearch(enable bool) error + + Services() (types.Services, error) + SetServices(services types.Services) error } type client struct { @@ -160,3 +163,15 @@ func (cl *client) ToggleFiltering(enabled bool, interval int) error { _, err := cl.client.R().EnableTrace().SetBody(&types.FilteringConfig{Enabled: enabled, Interval: interval}).Post("/filtering/config") return err } + +func (cl *client) Services() (types.Services, error) { + svcs := &types.Services{} + _, err := cl.client.R().EnableTrace().SetResult(svcs).Get("/blocked_services/list") + return *svcs, err +} + +func (cl *client) SetServices(services types.Services) error { + cl.log.With("services", len(services)).Info("Set services") + _, err := cl.client.R().EnableTrace().SetBody(&services).Post("/blocked_services/set") + return err +} diff --git a/pkg/types/types.go b/pkg/types/types.go index fa1a4c1..617bd95 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "sort" "strings" "time" ) @@ -73,8 +74,8 @@ type FilteringStatus struct { type UserRules []string -func (ur *UserRules) String() string { - return strings.Join(*ur, "\n") +func (ur UserRules) String() string { + return strings.Join(ur, "\n") } type FilteringConfig struct { @@ -109,3 +110,23 @@ func (fs *Filters) Merge(other Filters) (Filters, Filters) { return adds, removes } + +type Services []string + +func (s Services) Sort() { + sort.Strings(s) +} + +func (s Services) Equals(o Services) bool { + s.Sort() + o.Sort() + if len(s) != len(o) { + return false + } + for i, v := range s { + if v != o[i] { + return false + } + } + return true +}