diff --git a/main.go b/main.go index 42ea820..fda9163 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,9 @@ package main import ( - "github.com/bakito/adguardhome-sync/pkg/client" "os" + + "github.com/bakito/adguardhome-sync/pkg/client" ) const ( @@ -15,7 +16,6 @@ const ( ) func main() { - // Create a Resty Client origin, err := client.New(os.Getenv(envOriginApiURL), os.Getenv(envOriginUsername), os.Getenv(envOriginPassword)) if err != nil { @@ -26,6 +26,20 @@ func main() { panic(err) } + os, err := origin.Status() + if err != nil { + panic(err) + } + + rs, err := replica.Status() + if err != nil { + panic(err) + } + + if os.Version != rs.Version { + panic("Versions do not match") + } + err = syncRewrites(origin, replica) if err != nil { panic(err) @@ -40,7 +54,10 @@ func main() { panic(err) } - // POST http://192.168.2.207/control/dns_config {"protection_enabled":false} + err = syncClients(origin, replica) + if err != nil { + panic(err) + } } func syncServices(origin client.Client, replica client.Client) error { @@ -134,3 +151,27 @@ func syncRewrites(origin client.Client, replica client.Client) error { } return nil } + +func syncClients(origin client.Client, replica client.Client) error { + oc, err := origin.Clients() + if err != nil { + return err + } + rc, err := replica.Clients() + if err != nil { + return err + } + + a, u, r := rc.Merge(oc) + + if err = replica.AddClients(a...); err != nil { + return err + } + if err = replica.UpdateClients(u...); err != nil { + return err + } + if err = replica.DeleteClients(r...); err != nil { + return err + } + return nil +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 7703db1..1a54f87 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -2,11 +2,12 @@ package client import ( "fmt" + "net/url" + "github.com/bakito/adguardhome-sync/pkg/log" "github.com/bakito/adguardhome-sync/pkg/types" "github.com/go-resty/resty/v2" "go.uber.org/zap" - "net/url" ) var ( @@ -50,6 +51,11 @@ type Client interface { Services() (types.Services, error) SetServices(services types.Services) error + + Clients() (*types.Clients, error) + AddClients(client ...types.Client) error + UpdateClients(client ...types.Client) error + DeleteClients(client ...types.Client) error } type client struct { @@ -175,3 +181,42 @@ func (cl *client) SetServices(services types.Services) error { _, err := cl.client.R().EnableTrace().SetBody(&services).Post("/blocked_services/set") return err } + +func (cl *client) Clients() (*types.Clients, error) { + clients := &types.Clients{} + _, err := cl.client.R().EnableTrace().SetResult(clients).Get("/clients") + return clients, err +} + +func (cl *client) AddClients(clients ...types.Client) error { + for _, client := range clients { + cl.log.With("name", client.Name).Info("Add client") + _, err := cl.client.R().EnableTrace().SetBody(&client).Post("/clients/add") + if err != nil { + return err + } + } + return nil +} + +func (cl *client) UpdateClients(clients ...types.Client) error { + for _, client := range clients { + cl.log.With("name", client.Name).Info("Update client") + _, err := cl.client.R().EnableTrace().SetBody(&types.ClientUpdate{Name: client.Name, Data: client}).Post("/clients/update") + if err != nil { + return err + } + } + return nil +} + +func (cl *client) DeleteClients(clients ...types.Client) error { + for _, client := range clients { + cl.log.With("name", client.Name).Info("Delete client") + _, err := cl.client.R().EnableTrace().SetBody(&client).Post("/clients/delete") + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/types/types.go b/pkg/types/types.go index 617bd95..a15093a 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,6 +1,7 @@ package types import ( + "encoding/json" "fmt" "sort" "strings" @@ -120,11 +121,98 @@ func (s Services) Sort() { func (s Services) Equals(o Services) bool { s.Sort() o.Sort() - if len(s) != len(o) { + return equals(s, o) +} + +type Clients struct { + Clients []Client `json:"clients"` + AutoClients []struct { + IP string `json:"ip"` + Name string `json:"name"` + Source string `json:"source"` + WhoisInfo struct { + } `json:"whois_info"` + } `json:"auto_clients"` + SupportedTags []string `json:"supported_tags"` +} + +type Client struct { + Ids []string `json:"ids"` + Tags []string `json:"tags"` + BlockedServices []string `json:"blocked_services"` + Upstreams []string `json:"upstreams"` + + UseGlobalSettings bool `json:"use_global_settings"` + UseGlobalBlockedServices bool `json:"use_global_blocked_services"` + Name string `json:"name"` + FilteringEnabled bool `json:"filtering_enabled"` + ParentalEnabled bool `json:"parental_enabled"` + SafesearchEnabled bool `json:"safesearch_enabled"` + SafebrowsingEnabled bool `json:"safebrowsing_enabled"` + Disallowed bool `json:"disallowed"` + DisallowedRule string `json:"disallowed_rule"` +} + +func (cl *Client) Sort() { + sort.Strings(cl.Ids) + sort.Strings(cl.Tags) + sort.Strings(cl.BlockedServices) + sort.Strings(cl.Upstreams) +} + +func (cl *Client) Equal(o *Client) bool { + cl.Sort() + o.Sort() + + a, _ := json.Marshal(cl) + b, _ := json.Marshal(o) + return string(a) == string(b) +} + +func (clients *Clients) Merge(other *Clients) ([]Client, []Client, []Client) { + current := make(map[string]Client) + for _, client := range clients.Clients { + current[client.Name] = client + } + + expected := make(map[string]Client) + for _, client := range other.Clients { + expected[client.Name] = client + } + + var adds []Client + var removes []Client + var updates []Client + + for _, cl := range expected { + if oc, ok := current[cl.Name]; ok { + if !cl.Equal(&oc) { + updates = append(updates, cl) + } + delete(current, cl.Name) + } else { + adds = append(adds, cl) + } + } + + for _, rr := range current { + removes = append(removes, rr) + } + + return adds, updates, removes +} + +type ClientUpdate struct { + Name string `json:"name"` + Data Client `json:"data"` +} + +func equals(a []string, b []string) bool { + if len(a) != len(b) { return false } - for i, v := range s { - if v != o[i] { + for i, v := range a { + if v != b[i] { return false } }