headscale/acls.go

271 lines
5.9 KiB
Go
Raw Normal View History

2021-07-03 17:55:32 +08:00
package headscale
import (
"encoding/json"
2021-07-03 23:31:32 +08:00
"fmt"
2021-07-03 17:55:32 +08:00
"io"
"log"
2021-07-03 17:55:32 +08:00
"os"
"strconv"
2021-07-03 23:31:32 +08:00
"strings"
2021-07-03 17:55:32 +08:00
"github.com/tailscale/hujson"
2021-07-03 23:31:32 +08:00
"inet.af/netaddr"
"tailscale.com/tailcfg"
2021-07-03 17:55:32 +08:00
)
2021-07-03 23:31:32 +08:00
// Sentinel errors produced while loading the ACL policy and compiling it
// into Tailscale filter rules.
const (
	errorEmptyPolicy        = Error("empty policy")
	errorInvalidAction      = Error("invalid action")
	errorInvalidUserSection = Error("invalid user section")
	errorInvalidGroup       = Error("invalid group")
	errorInvalidTag         = Error("invalid tag")
	errorInvalidNamespace   = Error("invalid namespace")
	errorInvalidPortFormat  = Error("invalid port format")
)
2021-07-03 17:55:32 +08:00
2021-07-03 23:31:32 +08:00
func (h *Headscale) LoadPolicy(path string) error {
2021-07-03 17:55:32 +08:00
policyFile, err := os.Open(path)
if err != nil {
2021-07-03 23:31:32 +08:00
return err
2021-07-03 17:55:32 +08:00
}
defer policyFile.Close()
var policy ACLPolicy
b, err := io.ReadAll(policyFile)
if err != nil {
2021-07-03 23:31:32 +08:00
return err
2021-07-03 17:55:32 +08:00
}
err = hujson.Unmarshal(b, &policy)
if policy.IsZero() {
2021-07-03 23:31:32 +08:00
return errorEmptyPolicy
2021-07-03 17:55:32 +08:00
}
2021-07-03 23:31:32 +08:00
h.aclPolicy = &policy
return err
}
func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for i, a := range h.aclPolicy.ACLs {
if a.Action != "accept" {
return nil, errorInvalidAction
}
r := tailcfg.FilterRule{}
srcIPs := []string{}
for j, u := range a.Users {
fmt.Printf("acl %d, user %d: ", i, j)
srcs, err := h.generateAclPolicySrcIP(u)
fmt.Printf(" -> %s\n", err)
if err != nil {
return nil, err
}
srcIPs = append(srcIPs, *srcs...)
}
r.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
for j, d := range a.Ports {
fmt.Printf("acl %d, port %d: ", i, j)
dests, err := h.generateAclPolicyDestPorts(d)
fmt.Printf(" -> %s\n", err)
if err != nil {
return nil, err
}
destPorts = append(destPorts, *dests...)
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs,
DstPorts: destPorts,
})
2021-07-03 23:31:32 +08:00
}
return &rules, nil
}
// generateAclPolicySrcIP expands one entry of an ACL's Users list into the
// source addresses it denotes. Resolution is delegated entirely to
// expandAlias.
func (h *Headscale) generateAclPolicySrcIP(u string) (*[]string, error) {
	srcIPs, err := h.expandAlias(u)
	return srcIPs, err
}
// generateAclPolicyDestPorts expands one entry of an ACL's Ports list into
// destination port ranges. An entry has the form "<alias>:<ports>". The
// alias may itself contain one colon (e.g. "tag:web"), so the entry splits
// into either two or three colon-separated tokens:
//
//	git-server:*
//	192.168.1.0/24:22
//	tag:montreal-webserver:80,443
//	tag:api-server:443
//	example-host-1:*
//
// The result pairs every address the alias expands to with every port range
// parsed from the last token. Entries with fewer than two or more than three
// tokens yield errorInvalidPortFormat.
//
// NOTE(review): IPv6 literals contain colons, so they cannot be used as the
// alias here.
func (h *Headscale) generateAclPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) {
	parts := strings.Split(d, ":")
	count := len(parts)
	if count != 2 && count != 3 {
		return nil, errorInvalidPortFormat
	}

	// Everything before the final token names the hosts.
	alias := parts[0]
	if count == 3 {
		alias = fmt.Sprintf("%s:%s", parts[0], parts[1])
	}

	// The alias is resolved before the ports are parsed, so alias errors
	// take precedence over port errors.
	addrs, err := h.expandAlias(alias)
	if err != nil {
		return nil, err
	}
	ranges, err := h.expandPorts(parts[count-1])
	if err != nil {
		return nil, err
	}

	dests := make([]tailcfg.NetPortRange, 0, len(*addrs)*len(*ranges))
	for _, addr := range *addrs {
		for _, pr := range *ranges {
			dests = append(dests, tailcfg.NetPortRange{
				IP:    addr,
				Ports: pr,
			})
		}
	}
	return &dests, nil
}
func (h *Headscale) expandAlias(s string) (*[]string, error) {
if s == "*" {
fmt.Printf("%s -> wildcard", s)
2021-07-03 23:31:32 +08:00
return &[]string{"*"}, nil
}
if strings.HasPrefix(s, "group:") {
fmt.Printf("%s -> group", s)
if _, ok := h.aclPolicy.Groups[s]; !ok {
2021-07-03 23:31:32 +08:00
return nil, errorInvalidGroup
}
ips := []string{}
for _, n := range h.aclPolicy.Groups[s] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
return nil, errorInvalidNamespace
}
for _, node := range *nodes {
ips = append(ips, node.IPAddress)
}
}
return &ips, nil
2021-07-03 23:31:32 +08:00
}
if strings.HasPrefix(s, "tag:") {
fmt.Printf("%s -> tag", s)
if _, ok := h.aclPolicy.TagOwners[s]; !ok {
return nil, errorInvalidTag
}
// This will have HORRIBLE performance.
// We need to change the data model to better store tags
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
machines := []Machine{}
if err = db.Where("registered").Find(&machines).Error; err != nil {
log.Printf("Error accessing db: %s", err)
return nil, err
}
ips := []string{}
for _, m := range machines {
hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 {
hi, err := m.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(hi, &hostinfo)
if err != nil {
return nil, err
}
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
if s[4:] == t {
ips = append(ips, m.IPAddress)
break
}
}
}
}
return &ips, nil
2021-07-03 23:31:32 +08:00
}
n, err := h.GetNamespace(s)
2021-07-03 23:31:32 +08:00
if err == nil {
fmt.Printf("%s -> namespace %s", s, n.Name)
2021-07-03 23:31:32 +08:00
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
return nil, err
}
ips := []string{}
for _, n := range *nodes {
ips = append(ips, n.IPAddress)
}
return &ips, nil
}
if h, ok := h.aclPolicy.Hosts[s]; ok {
fmt.Printf("%s -> host %s", s, h)
2021-07-03 23:31:32 +08:00
return &[]string{h.String()}, nil
}
ip, err := netaddr.ParseIP(s)
2021-07-03 23:31:32 +08:00
if err == nil {
fmt.Printf(" %s -> ip %s", s, ip)
2021-07-03 23:31:32 +08:00
return &[]string{ip.String()}, nil
}
cidr, err := netaddr.ParseIPPrefix(s)
2021-07-03 23:31:32 +08:00
if err == nil {
fmt.Printf("%s -> cidr %s", s, cidr)
2021-07-03 23:31:32 +08:00
return &[]string{cidr.String()}, nil
}
fmt.Printf("%s: cannot be mapped to anything\n", s)
2021-07-03 23:31:32 +08:00
return nil, errorInvalidUserSection
2021-07-03 17:55:32 +08:00
}
// expandPorts parses the port part of an ACL destination into port ranges.
//
// "*" means every port (0-65535). Otherwise s is a comma-separated list
// whose items are either a single port ("22") or an inclusive range
// ("8000-8080"). Each port must be a decimal value that fits in 16 bits.
//
// Two kinds of error can be returned. A strconv error is returned for a
// non-numeric or out-of-range port, including an empty item. Otherwise
// errorInvalidPortFormat is returned when an item has more than one "-" or
// its range is inverted (start > end).
func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
	if s == "*" {
		return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
	}

	items := strings.Split(s, ",")
	ports := make([]tailcfg.PortRange, 0, len(items))
	for _, item := range items {
		bounds := strings.Split(item, "-")
		if len(bounds) > 2 {
			return nil, errorInvalidPortFormat
		}

		first, err := strconv.ParseUint(bounds[0], 10, 16)
		if err != nil {
			return nil, err
		}
		last := first
		if len(bounds) == 2 {
			last, err = strconv.ParseUint(bounds[1], 10, 16)
			if err != nil {
				return nil, err
			}
		}
		// An inverted range would otherwise be passed on as First > Last.
		if first > last {
			return nil, errorInvalidPortFormat
		}

		ports = append(ports, tailcfg.PortRange{
			First: uint16(first),
			Last:  uint16(last),
		})
	}
	return &ports, nil
}