handle clients

This commit is contained in:
bakito 2021-03-28 09:58:54 +02:00
parent b4c5380738
commit 0c17b13f96
No known key found for this signature in database
GPG key ID: FAF93C1C384DD6B4
3 changed files with 181 additions and 7 deletions

47
main.go
View file

@ -1,8 +1,9 @@
package main package main
import ( import (
"github.com/bakito/adguardhome-sync/pkg/client"
"os" "os"
"github.com/bakito/adguardhome-sync/pkg/client"
) )
const ( const (
@ -15,7 +16,6 @@ const (
) )
func main() { func main() {
// Create a Resty Client
origin, err := client.New(os.Getenv(envOriginApiURL), os.Getenv(envOriginUsername), os.Getenv(envOriginPassword)) origin, err := client.New(os.Getenv(envOriginApiURL), os.Getenv(envOriginUsername), os.Getenv(envOriginPassword))
if err != nil { if err != nil {
@ -26,6 +26,20 @@ func main() {
panic(err) panic(err)
} }
os, err := origin.Status()
if err != nil {
panic(err)
}
rs, err := replica.Status()
if err != nil {
panic(err)
}
if os.Version != rs.Version {
panic("Versions do not match")
}
err = syncRewrites(origin, replica) err = syncRewrites(origin, replica)
if err != nil { if err != nil {
panic(err) panic(err)
@ -40,7 +54,10 @@ func main() {
panic(err) panic(err)
} }
// POST http://192.168.2.207/control/dns_config {"protection_enabled":false} err = syncClients(origin, replica)
if err != nil {
panic(err)
}
} }
func syncServices(origin client.Client, replica client.Client) error { func syncServices(origin client.Client, replica client.Client) error {
@ -134,3 +151,27 @@ func syncRewrites(origin client.Client, replica client.Client) error {
} }
return nil return nil
} }
func syncClients(origin client.Client, replica client.Client) error {
oc, err := origin.Clients()
if err != nil {
return err
}
rc, err := replica.Clients()
if err != nil {
return err
}
a, u, r := rc.Merge(oc)
if err = replica.AddClients(a...); err != nil {
return err
}
if err = replica.UpdateClients(u...); err != nil {
return err
}
if err = replica.DeleteClients(r...); err != nil {
return err
}
return nil
}

View file

@ -2,11 +2,12 @@ package client
import ( import (
"fmt" "fmt"
"net/url"
"github.com/bakito/adguardhome-sync/pkg/log" "github.com/bakito/adguardhome-sync/pkg/log"
"github.com/bakito/adguardhome-sync/pkg/types" "github.com/bakito/adguardhome-sync/pkg/types"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"go.uber.org/zap" "go.uber.org/zap"
"net/url"
) )
var ( var (
@ -50,6 +51,11 @@ type Client interface {
Services() (types.Services, error) Services() (types.Services, error)
SetServices(services types.Services) error SetServices(services types.Services) error
Clients() (*types.Clients, error)
AddClients(client ...types.Client) error
UpdateClients(client ...types.Client) error
DeleteClients(client ...types.Client) error
} }
type client struct { type client struct {
@ -175,3 +181,42 @@ func (cl *client) SetServices(services types.Services) error {
_, err := cl.client.R().EnableTrace().SetBody(&services).Post("/blocked_services/set") _, err := cl.client.R().EnableTrace().SetBody(&services).Post("/blocked_services/set")
return err return err
} }
func (cl *client) Clients() (*types.Clients, error) {
clients := &types.Clients{}
_, err := cl.client.R().EnableTrace().SetResult(clients).Get("/clients")
return clients, err
}
func (cl *client) AddClients(clients ...types.Client) error {
for _, client := range clients {
cl.log.With("name", client.Name).Info("Add client")
_, err := cl.client.R().EnableTrace().SetBody(&client).Post("/clients/add")
if err != nil {
return err
}
}
return nil
}
func (cl *client) UpdateClients(clients ...types.Client) error {
for _, client := range clients {
cl.log.With("name", client.Name).Info("Update client")
_, err := cl.client.R().EnableTrace().SetBody(&types.ClientUpdate{Name: client.Name, Data: client}).Post("/clients/update")
if err != nil {
return err
}
}
return nil
}
func (cl *client) DeleteClients(clients ...types.Client) error {
for _, client := range clients {
cl.log.With("name", client.Name).Info("Delete client")
_, err := cl.client.R().EnableTrace().SetBody(&client).Post("/clients/delete")
if err != nil {
return err
}
}
return nil
}

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"encoding/json"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@ -120,11 +121,98 @@ func (s Services) Sort() {
func (s Services) Equals(o Services) bool { func (s Services) Equals(o Services) bool {
s.Sort() s.Sort()
o.Sort() o.Sort()
if len(s) != len(o) { return equals(s, o)
}
type Clients struct {
Clients []Client `json:"clients"`
AutoClients []struct {
IP string `json:"ip"`
Name string `json:"name"`
Source string `json:"source"`
WhoisInfo struct {
} `json:"whois_info"`
} `json:"auto_clients"`
SupportedTags []string `json:"supported_tags"`
}
type Client struct {
Ids []string `json:"ids"`
Tags []string `json:"tags"`
BlockedServices []string `json:"blocked_services"`
Upstreams []string `json:"upstreams"`
UseGlobalSettings bool `json:"use_global_settings"`
UseGlobalBlockedServices bool `json:"use_global_blocked_services"`
Name string `json:"name"`
FilteringEnabled bool `json:"filtering_enabled"`
ParentalEnabled bool `json:"parental_enabled"`
SafesearchEnabled bool `json:"safesearch_enabled"`
SafebrowsingEnabled bool `json:"safebrowsing_enabled"`
Disallowed bool `json:"disallowed"`
DisallowedRule string `json:"disallowed_rule"`
}
func (cl *Client) Sort() {
sort.Strings(cl.Ids)
sort.Strings(cl.Tags)
sort.Strings(cl.BlockedServices)
sort.Strings(cl.Upstreams)
}
func (cl *Client) Equal(o *Client) bool {
cl.Sort()
o.Sort()
a, _ := json.Marshal(cl)
b, _ := json.Marshal(o)
return string(a) == string(b)
}
func (clients *Clients) Merge(other *Clients) ([]Client, []Client, []Client) {
current := make(map[string]Client)
for _, client := range clients.Clients {
current[client.Name] = client
}
expected := make(map[string]Client)
for _, client := range other.Clients {
expected[client.Name] = client
}
var adds []Client
var removes []Client
var updates []Client
for _, cl := range expected {
if oc, ok := current[cl.Name]; ok {
if !cl.Equal(&oc) {
updates = append(updates, cl)
}
delete(current, cl.Name)
} else {
adds = append(adds, cl)
}
}
for _, rr := range current {
removes = append(removes, rr)
}
return adds, updates, removes
}
type ClientUpdate struct {
Name string `json:"name"`
Data Client `json:"data"`
}
func equals(a []string, b []string) bool {
if len(a) != len(b) {
return false return false
} }
for i, v := range s { for i, v := range a {
if v != o[i] { if v != b[i] {
return false return false
} }
} }