diff --git a/acls.go b/acls.go index 151fb3b6..281ea2c4 100644 --- a/acls.go +++ b/acls.go @@ -22,7 +22,7 @@ const errorInvalidTag = Error("invalid tag") const errorInvalidNamespace = Error("invalid namespace") const errorInvalidPortFormat = Error("invalid port format") -func (h *Headscale) LoadPolicy(path string) error { +func (h *Headscale) LoadAclPolicy(path string) error { policyFile, err := os.Open(path) if err != nil { return err @@ -40,7 +40,12 @@ func (h *Headscale) LoadPolicy(path string) error { } h.aclPolicy = &policy - return err + rules, err := h.generateACLRules() + if err != nil { + return err + } + h.aclRules = rules + return nil } func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) { diff --git a/api.go b/api.go index 92501a4e..ab805ef1 100644 --- a/api.go +++ b/api.go @@ -373,7 +373,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac DNS: []netaddr.IP{}, SearchPaths: []string{}, Domain: "foobar@example.com", - PacketFilter: tailcfg.FilterAllowAll, + PacketFilter: *h.aclRules, DERPMap: h.cfg.DerpMap, UserProfiles: []tailcfg.UserProfile{}, } diff --git a/app.go b/app.go index 4775c6ec..52e72bce 100644 --- a/app.go +++ b/app.go @@ -50,6 +50,7 @@ type Headscale struct { privateKey *wgkey.Private aclPolicy *ACLPolicy + aclRules *[]tailcfg.FilterRule pollMu sync.Mutex clientsPolling map[uint64]chan []byte // this is by all means a hackity hack @@ -84,7 +85,9 @@ func NewHeadscale(cfg Config) (*Headscale, error) { dbString: dbString, privateKey: privKey, publicKey: &pubKey, + aclRules: &tailcfg.FilterAllowAll, // default allowall } + err = h.initDB() if err != nil { return nil, err diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 52a9368e..c606b6d9 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -119,6 +119,13 @@ func getHeadscaleApp() (*headscale.Headscale, error) { if err != nil { return nil, err } + + // We are doing this here, as in the future could be cool to have it also hot-reload + err = h.LoadAclPolicy(absPath(viper.GetString("acl_policy_path"))) + if err != nil { + log.Printf("Could not load the ACL policy: %s", err) + } + return h, nil }