diff --git a/cmd/root.go b/cmd/root.go index f13d563..6e21160 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,7 +8,7 @@ import ( "github.com/bakito/adguardhome-sync/pkg/log" "github.com/bakito/adguardhome-sync/pkg/types" - homedir "github.com/mitchellh/go-homedir" + "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "github.com/spf13/viper" ) diff --git a/pkg/client/client.go b/pkg/client/client.go index d35147e..b445dad 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -20,7 +20,6 @@ var ( // New create a new client func New(config types.AdGuardInstance) (Client, error) { - var apiURL string if config.APIPath == "" { apiURL = fmt.Sprintf("%s/control", config.URL) @@ -124,7 +123,6 @@ func (cl *client) Status() (*types.Status, error) { status := &types.Status{} err := cl.doGet(cl.client.R().EnableTrace().SetResult(status), "status") return status, err - } func (cl *client) RewriteList() (*types.RewriteEntries, error) { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 4f1cb25..cf01682 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -6,36 +6,289 @@ import ( "net/http/httptest" "path/filepath" - "github.com/bakito/adguardhome-sync/pkg/client" - "github.com/bakito/adguardhome-sync/pkg/types" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + + "github.com/bakito/adguardhome-sync/pkg/client" + "github.com/bakito/adguardhome-sync/pkg/types" ) var _ = Describe("Client", func() { var ( cl client.Client + ts *httptest.Server ) + AfterEach(func() { + if ts != nil { + ts.Close() + } + }) + + Context("Host", func() { + It("should read the current host", func() { + cl, _ := client.New(types.AdGuardInstance{URL: "https://foo.bar:3000"}) + host := cl.Host() + Ω(host).Should(Equal("foo.bar:3000")) + }) + }) Context("Filtering", func() { - It("should reade filtering status", func() { - cl = Serve("filtering-status.json") - _, err := cl.Filtering() + It("should read filtering status", func() { + ts, cl = ClientGet("filtering-status.json", "/filtering/status") + fs, err := cl.Filtering() + Ω(err).ShouldNot(HaveOccurred()) + Ω(fs.Enabled).Should(BeTrue()) + Ω(fs.Filters).Should(HaveLen(2)) + }) + It("should enable protection", func() { + ts, cl = ClientPost("/filtering/config", `{"enabled":true,"interval":123}`) + err := cl.ToggleFiltering(true, 123) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should disable protection", func() { + ts, cl = ClientPost("/filtering/config", `{"enabled":false,"interval":123}`) + err := cl.ToggleFiltering(false, 123) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should call RefreshFilters", func() { + ts, cl = ClientPost("/filtering/refresh", `{"whitelist":true}`) + err := cl.RefreshFilters(true) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should add Filters", func() { + ts, cl = ClientPost("/filtering/add_url", + `{"id":0,"enabled":false,"url":"foo","name":"","rules_count":0,"whitelist":true}`, + `{"id":0,"enabled":false,"url":"bar","name":"","rules_count":0,"whitelist":true}`, + ) + err := cl.AddFilters(true, types.Filter{URL: "foo"}, types.Filter{URL: "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should update Filters", func() { + ts, cl = ClientPost("/filtering/set_url", + `{"url":"foo","data":{"id":0,"enabled":false,"url":"foo","name":"","rules_count":0,"whitelist":true},"whitelist":true}`, + `{"url":"bar","data":{"id":0,"enabled":false,"url":"bar","name":"","rules_count":0,"whitelist":true},"whitelist":true}`, + ) + err := cl.UpdateFilters(true, types.Filter{URL: "foo"}, types.Filter{URL: "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should delete Filters", func() { + ts, cl = ClientPost("/filtering/remove_url", + `{"id":0,"enabled":false,"url":"foo","name":"","rules_count":0,"whitelist":true}`, + `{"id":0,"enabled":false,"url":"bar","name":"","rules_count":0,"whitelist":true}`, + ) + err := cl.DeleteFilters(true, types.Filter{URL: "foo"}, types.Filter{URL: "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("CustomRules", func() { + It("should set SetCustomRules", func() { + ts, cl = ClientPost("/filtering/set_rules", `foo +bar`) + err := cl.SetCustomRules([]string{"foo", "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("Status", func() { + It("should read status", func() { + ts, cl = ClientGet("status.json", "/status") + fs, err := cl.Status() + Ω(err).ShouldNot(HaveOccurred()) + Ω(fs.DNSAddresses).Should(HaveLen(1)) + Ω(fs.DNSAddresses[0]).Should(Equal("192.168.1.2")) + Ω(fs.Version).Should(Equal("v0.105.2")) + }) + }) + + Context("RewriteList", func() { + It("should read RewriteList", func() { + ts, cl = ClientGet("rewrite-list.json", "/rewrite/list") + rwl, err := cl.RewriteList() + Ω(err).ShouldNot(HaveOccurred()) + Ω(*rwl).Should(HaveLen(2)) + }) + It("should add RewriteList", func() { + ts, cl = ClientPost("/rewrite/add", `{"domain":"foo","answer":"foo"}`, `{"domain":"bar","answer":"bar"}`) + err := cl.AddRewriteEntries(types.RewriteEntry{Answer: "foo", Domain: "foo"}, types.RewriteEntry{Answer: "bar", Domain: "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should delete RewriteList", func() { + ts, cl = ClientPost("/rewrite/delete", `{"domain":"foo","answer":"foo"}`, `{"domain":"bar","answer":"bar"}`) + err := cl.DeleteRewriteEntries(types.RewriteEntry{Answer: "foo", Domain: "foo"}, types.RewriteEntry{Answer: "bar", Domain: "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("SafeBrowsing", func() { + It("should read safebrowsing status", func() { + ts, cl = ClientGet("safebrowsing-status.json", "/safebrowsing/status") + sb, err := cl.SafeBrowsing() + Ω(err).ShouldNot(HaveOccurred()) + Ω(sb).Should(BeTrue()) + }) + It("should enable safebrowsing", func() { + ts, cl = ClientPost("/safebrowsing/enable", "") + err := cl.ToggleSafeBrowsing(true) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should disable safebrowsing", func() { + ts, cl = ClientPost("/safebrowsing/disable", "") + err := cl.ToggleSafeBrowsing(false) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("SafeSearch", func() { + It("should read safesearch status", func() { + ts, cl = ClientGet("safesearch-status.json", "/safesearch/status") + ss, err := cl.SafeSearch() + Ω(err).ShouldNot(HaveOccurred()) + Ω(ss).Should(BeTrue()) + }) + It("should enable safesearch", func() { + ts, cl = ClientPost("/safesearch/enable", "") + err := cl.ToggleSafeSearch(true) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should disable safesearch", func() { + ts, cl = ClientPost("/safesearch/disable", "") + err := cl.ToggleSafeSearch(false) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("Parental", func() { + It("should read parental status", func() { + ts, cl = ClientGet("parental-status.json", "/parental/status") + p, err := cl.Parental() + Ω(err).ShouldNot(HaveOccurred()) + Ω(p).Should(BeTrue()) + }) + It("should enable parental", func() { + ts, cl = ClientPost("/parental/enable", "") + err := cl.ToggleParental(true) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should disable parental", func() { + ts, cl = ClientPost("/parental/disable", "") + err := cl.ToggleParental(false) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("Protection", func() { + It("should enable protection", func() { + ts, cl = ClientPost("/dns_config", `{"protection_enabled":true}`) + err := cl.ToggleProtection(true) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should disable protection", func() { + ts, cl = ClientPost("/dns_config", `{"protection_enabled":false}`) + err := cl.ToggleProtection(false) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("Services", func() { + It("should read Services", func() { + ts, cl = ClientGet("blockedservices-list.json", "/blocked_services/list") + s, err := cl.Services() + Ω(err).ShouldNot(HaveOccurred()) + Ω(s).Should(HaveLen(2)) + }) + It("should set Services", func() { + ts, cl = ClientPost("/blocked_services/set", `["foo","bar"]`) + err := cl.SetServices([]string{"foo", "bar"}) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("Clients", func() { + It("should read Clients", func() { + ts, cl = ClientGet("clients.json", "/clients") + c, err := cl.Clients() + Ω(err).ShouldNot(HaveOccurred()) + Ω(c.Clients).Should(HaveLen(2)) + }) + It("should add Clients", func() { + ts, cl = ClientPost("/clients/add", + `{"ids":["id"],"use_global_settings":false,"use_global_blocked_services":false,"name":"foo","filtering_enabled":false,"parental_enabled":false,"safesearch_enabled":false,"safebrowsing_enabled":false,"disallowed":false,"disallowed_rule":""}`, + ) + err := cl.AddClients(types.Client{Name: "foo", Ids: []string{"id"}}) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should update Clients", func() { + ts, cl = ClientPost("/clients/update", + `{"name":"foo","data":{"ids":["id"],"use_global_settings":false,"use_global_blocked_services":false,"name":"foo","filtering_enabled":false,"parental_enabled":false,"safesearch_enabled":false,"safebrowsing_enabled":false,"disallowed":false,"disallowed_rule":""}}`, + ) + err := cl.UpdateClients(types.Client{Name: "foo", Ids: []string{"id"}}) + Ω(err).ShouldNot(HaveOccurred()) + }) + It("should delete Clients", func() { + ts, cl = ClientPost("/clients/delete", + `{"ids":["id"],"use_global_settings":false,"use_global_blocked_services":false,"name":"foo","filtering_enabled":false,"parental_enabled":false,"safesearch_enabled":false,"safebrowsing_enabled":false,"disallowed":false,"disallowed_rule":""}`, + ) + err := cl.DeleteClients(types.Client{Name: "foo", Ids: []string{"id"}}) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + + Context("QueryLogConfig", func() { + It("should read QueryLogConfig", func() { + ts, cl = ClientGet("querylog_info.json", "/querylog_info") + qlc, err := cl.QueryLogConfig() + Ω(err).ShouldNot(HaveOccurred()) + Ω(qlc.Enabled).Should(BeTrue()) + Ω(qlc.Interval).Should(Equal(90)) + }) + It("should set QueryLogConfig", func() { + ts, cl = ClientPost("/querylog_config", `{"enabled":true,"interval":123,"anonymize_client_ip":true}`) + err := cl.SetQueryLogConfig(true, 123, true) + Ω(err).ShouldNot(HaveOccurred()) + }) + }) + Context("StatsConfig", func() { + It("should read StatsConfig", func() { + ts, cl = ClientGet("stats_info.json", "/stats_info") + sc, err := cl.StatsConfig() + Ω(err).ShouldNot(HaveOccurred()) + Ω(sc.Interval).Should(Equal(1)) + }) + It("should set StatsConfig", func() { + ts, cl = ClientPost("/stats_config", `{"interval":123}`) + err := cl.SetStatsConfig(123) Ω(err).ShouldNot(HaveOccurred()) }) }) }) -func Serve(file string) client.Client { +func ClientGet(file string, path string) (*httptest.Server, client.Client) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Ω(r.URL.Path).Should(Equal(types.DefaultAPIPath + path)) b, err := ioutil.ReadFile(filepath.Join("../../testdata", file)) Ω(err).ShouldNot(HaveOccurred()) + w.Header().Set("Content-Type", "application/json") _, err = w.Write(b) Ω(err).ShouldNot(HaveOccurred()) })) cl, err := client.New(types.AdGuardInstance{URL: ts.URL}) Ω(err).ShouldNot(HaveOccurred()) - return cl + return ts, cl +} + +func ClientPost(path string, content ...string) (*httptest.Server, client.Client) { + index := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Ω(r.URL.Path).Should(Equal(types.DefaultAPIPath + path)) + body, err := ioutil.ReadAll(r.Body) + Ω(err).ShouldNot(HaveOccurred()) + Ω(body).Should(Equal([]byte(content[index]))) + index++ + })) + + cl, err := client.New(types.AdGuardInstance{URL: ts.URL}) + Ω(err).ShouldNot(HaveOccurred()) + return ts, cl } diff --git a/pkg/types/types.go b/pkg/types/types.go index ca56358..0f37814 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -252,10 +252,10 @@ type Clients struct { // Client API struct type Client struct { - Ids []string `json:"ids"` - Tags []string `json:"tags"` - BlockedServices []string `json:"blocked_services"` - Upstreams []string `json:"upstreams"` + Ids []string `json:"ids,omitempty"` + Tags []string `json:"tags,omitempty"` + BlockedServices []string `json:"blocked_services,omitempty"` + Upstreams []string `json:"upstreams,omitempty"` UseGlobalSettings bool `json:"use_global_settings"` UseGlobalBlockedServices bool `json:"use_global_blocked_services"` diff --git a/testdata/blockedservices-list.json b/testdata/blockedservices-list.json new file mode 100644 index 0000000..a0c5b2d --- /dev/null +++ b/testdata/blockedservices-list.json @@ -0,0 +1,4 @@ +[ + "9gag", + "dailymotion" +] diff --git a/testdata/clients.json b/testdata/clients.json new file mode 100644 index 0000000..6ab1c3a --- /dev/null +++ b/testdata/clients.json @@ -0,0 +1,81 @@ +{ + "clients": [ + { + "ids": [ + "192.168.1.3" + ], + "tags": [ + "device_pc" + ], + "name": "PC", + "use_global_settings": true, + "filtering_enabled": false, + "parental_enabled": false, + "safesearch_enabled": false, + "safebrowsing_enabled": false, + "use_global_blocked_services": true, + "blocked_services": null, + "upstreams": null, + "whois_info": null, + "disallowed": false, + "disallowed_rule": "" + }, + { + "ids": [ + "192.168.1.2" + ], + "tags": [ + "device_phone" + ], + "name": "Phone LAN", + "use_global_settings": true, + "filtering_enabled": false, + "parental_enabled": false, + "safesearch_enabled": false, + "safebrowsing_enabled": false, + "use_global_blocked_services": false, + "blocked_services": [ + "facebook", + "ok", + "vk", + "mail_ru", + "qq" + ], + "upstreams": [], + "whois_info": null, + "disallowed": false, + "disallowed_rule": "" + } + ], + "auto_clients": [ + { + "ip": "127.0.0.1", + "name": "localhost", + "source": "etc/hosts", + "whois_info": {} + } + ], + "supported_tags": [ + "device_audio", + "device_camera", + "device_gameconsole", + "device_laptop", + "device_nas", + "device_other", + "device_pc", + "device_phone", + "device_printer", + "device_securityalarm", + "device_tablet", + "device_tv", + "os_android", + "os_ios", + "os_linux", + "os_macos", + "os_other", + "os_windows", + "user_admin", + "user_child", + "user_regular" + ] +} diff --git a/testdata/parental-status.json b/testdata/parental-status.json new file mode 100644 index 0000000..37c8cef --- /dev/null +++ b/testdata/parental-status.json @@ -0,0 +1,3 @@ +{ + "enabled": true +} \ No newline at end of file diff --git a/testdata/querylog_info.json b/testdata/querylog_info.json new file mode 100644 index 0000000..9418849 --- /dev/null +++ b/testdata/querylog_info.json @@ -0,0 +1,5 @@ +{ + "enabled": true, + "interval": 90, + "anonymize_client_ip": false +} \ No newline at end of file diff --git a/testdata/rewrite-list.json b/testdata/rewrite-list.json new file mode 100644 index 0000000..6e21e8a --- /dev/null +++ b/testdata/rewrite-list.json @@ -0,0 +1,10 @@ +[ + { + "domain": "foo.com", + "answer": "192.168.1.10" + }, + { + "domain": "bar.com", + "answer": "192.168.1.12" + } +] diff --git a/testdata/safebrowsing-status.json b/testdata/safebrowsing-status.json new file mode 100644 index 0000000..37c8cef --- /dev/null +++ b/testdata/safebrowsing-status.json @@ -0,0 +1,3 @@ +{ + "enabled": true +} \ No newline at end of file diff --git a/testdata/safesearch-status.json b/testdata/safesearch-status.json new file mode 100644 index 0000000..37c8cef --- /dev/null +++ b/testdata/safesearch-status.json @@ -0,0 +1,3 @@ +{ + "enabled": true +} \ No newline at end of file diff --git a/testdata/stats_info.json b/testdata/stats_info.json new file mode 100644 index 0000000..58d9443 --- /dev/null +++ b/testdata/stats_info.json @@ -0,0 +1,3 @@ +{ + "interval": 1 +} \ No newline at end of file diff --git a/testdata/status.json b/testdata/status.json new file mode 100644 index 0000000..c430486 --- /dev/null +++ b/testdata/status.json @@ -0,0 +1,12 @@ +{ + "dns_addresses": [ + "192.168.1.2" + ], + "dns_port": 53, + "http_port": 45158, + "protection_enabled": true, + "dhcp_available": true, + "running": true, + "version": "v0.105.2", + "language": "en" +} \ No newline at end of file