mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-17 10:34:38 +08:00
enforce new acl policy access check
This commit is contained in:
parent
00b082d11c
commit
fcd3325173
3 changed files with 110 additions and 23 deletions
122
logic/acls.go
122
logic/acls.go
|
@ -2,8 +2,8 @@ package logic
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
|
@ -38,22 +38,22 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|||
case models.UserPolicy:
|
||||
// src list should only contain users
|
||||
for _, srcI := range acl.Src {
|
||||
userTagLi := strings.Split(srcI, ":")
|
||||
if len(userTagLi) < 2 {
|
||||
|
||||
if srcI.ID == "" || srcI.Value == "" {
|
||||
break
|
||||
}
|
||||
if userTagLi[0] != models.UserAclID.String() &&
|
||||
userTagLi[0] != models.UserGroupAclID.String() {
|
||||
if srcI.ID != models.UserAclID &&
|
||||
srcI.ID != models.UserGroupAclID {
|
||||
break
|
||||
}
|
||||
// check if user group is valid
|
||||
if userTagLi[0] == models.UserAclID.String() {
|
||||
_, err := GetUser(userTagLi[1])
|
||||
if srcI.ID == models.UserAclID {
|
||||
_, err := GetUser(srcI.Value)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
} else if userTagLi[0] == models.UserGroupAclID.String() {
|
||||
err := IsGroupValid(models.UserGroupID(userTagLi[1]))
|
||||
} else if srcI.ID == models.UserGroupAclID {
|
||||
err := IsGroupValid(models.UserGroupID(srcI.Value))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
@ -61,19 +61,19 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|||
|
||||
}
|
||||
for _, dstI := range acl.Dst {
|
||||
dstILi := strings.Split(dstI, ":")
|
||||
if len(dstILi) < 2 {
|
||||
|
||||
if dstI.ID == "" || dstI.Value == "" {
|
||||
break
|
||||
}
|
||||
if dstILi[0] == models.UserAclID.String() ||
|
||||
dstILi[0] == models.UserGroupAclID.String() {
|
||||
if dstI.ID == models.UserAclID ||
|
||||
dstI.ID == models.UserGroupAclID {
|
||||
break
|
||||
}
|
||||
if dstILi[0] != models.DeviceAclID.String() {
|
||||
if dstI.ID != models.DeviceAclID {
|
||||
break
|
||||
}
|
||||
// check if tag is valid
|
||||
_, err := GetTag(models.TagID(dstILi[1]))
|
||||
_, err := GetTag(models.TagID(dstI.Value))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
@ -81,20 +81,29 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|||
isValid = true
|
||||
case models.DevicePolicy:
|
||||
for _, srcI := range acl.Src {
|
||||
deviceTagLi := strings.Split(srcI, ":")
|
||||
if len(deviceTagLi) < 2 {
|
||||
if srcI.ID == "" || srcI.Value == "" {
|
||||
break
|
||||
}
|
||||
if deviceTagLi[0] != models.DeviceAclID.String() {
|
||||
if srcI.ID != models.DeviceAclID {
|
||||
break
|
||||
}
|
||||
// check if tag is valid
|
||||
_, err := GetTag(models.TagID(srcI.Value))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, dstI := range acl.Dst {
|
||||
deviceTagLi := strings.Split(dstI, ":")
|
||||
if len(deviceTagLi) < 2 {
|
||||
|
||||
if dstI.ID == "" || dstI.Value == "" {
|
||||
break
|
||||
}
|
||||
if deviceTagLi[0] != models.DeviceAclID.String() {
|
||||
if dstI.ID != models.DeviceAclID {
|
||||
break
|
||||
}
|
||||
// check if tag is valid
|
||||
_, err := GetTag(models.TagID(dstI.Value))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -124,6 +133,36 @@ func DeleteAcl(a models.Acl) error {
|
|||
return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID.String())
|
||||
}
|
||||
|
||||
func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
|
||||
acls, _ := ListAcls(netID)
|
||||
for _, acl := range acls {
|
||||
if acl.Default && acl.RuleType == ruleType {
|
||||
return acl, nil
|
||||
}
|
||||
}
|
||||
return models.Acl{}, errors.New("default rule not found")
|
||||
}
|
||||
|
||||
// listDevicePolicies - lists all device policies in a network
|
||||
func listDevicePolicies(netID models.NetworkID) []models.Acl {
|
||||
data, err := database.FetchRecords(database.TAG_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return []models.Acl{}
|
||||
}
|
||||
acls := []models.Acl{}
|
||||
for _, dataI := range data {
|
||||
acl := models.Acl{}
|
||||
err := json.Unmarshal([]byte(dataI), &acl)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
|
||||
acls = append(acls, acl)
|
||||
}
|
||||
}
|
||||
return acls
|
||||
}
|
||||
|
||||
// ListAcls - lists all acl policies
|
||||
func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
|
||||
data, err := database.FetchRecords(database.TAG_TABLE_NAME)
|
||||
|
@ -144,6 +183,47 @@ func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
|
|||
return acls, nil
|
||||
}
|
||||
|
||||
func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
|
||||
aclValueMap := make(map[string]struct{})
|
||||
for _, aclTagI := range acltags {
|
||||
aclValueMap[aclTagI.ID.String()] = struct{}{}
|
||||
}
|
||||
return aclValueMap
|
||||
}
|
||||
|
||||
func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
|
||||
// check default policy if all allowed return true
|
||||
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
if err == nil {
|
||||
if defaultPolicy.Enabled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// list device policies
|
||||
policies := listDevicePolicies(models.NetworkID(peer.Network))
|
||||
for _, policy := range policies {
|
||||
srcMap := convAclTagToValueMap(policy.Src)
|
||||
dstMap := convAclTagToValueMap(policy.Dst)
|
||||
for tagID := range peer.Tags {
|
||||
if _, ok := dstMap[tagID.String()]; ok {
|
||||
for tagID := range node.Tags {
|
||||
if _, ok := srcMap[tagID.String()]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, ok := srcMap[tagID.String()]; ok {
|
||||
for tagID := range node.Tags {
|
||||
if _, ok := dstMap[tagID.String()]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SortTagEntrys - Sorts slice of Tag entries by their id
|
||||
func SortAclEntrys(acls []models.Acl) {
|
||||
sort.Slice(acls, func(i, j int) bool {
|
||||
|
|
|
@ -241,6 +241,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
|||
!peer.PendingDelete &&
|
||||
peer.Connected &&
|
||||
nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
|
||||
IsNodeAllowedToCommunicate(node, peer) &&
|
||||
(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
|
||||
peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
|
||||
}
|
||||
|
|
|
@ -23,6 +23,11 @@ const (
|
|||
DevicePolicy AclPolicyType = "device-policy"
|
||||
)
|
||||
|
||||
type AclPolicyTag struct {
|
||||
ID AclGroupType `json:"id"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type AclGroupType string
|
||||
|
||||
const (
|
||||
|
@ -39,11 +44,12 @@ func (g AclGroupType) String() string {
|
|||
|
||||
type Acl struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Default bool `json:"default"`
|
||||
Name string `json:"name"`
|
||||
NetworkID NetworkID `json:"network_id"`
|
||||
RuleType AclPolicyType `json:"policy_type"`
|
||||
Src []string `json:"src_type"`
|
||||
Dst []string `json:"dst_type"`
|
||||
Src []AclPolicyTag `json:"src_type"`
|
||||
Dst []AclPolicyTag `json:"dst_type"`
|
||||
AllowedDirection AllowedTrafficDirection `json:"allowed_traffic_direction"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
|
|
Loading…
Add table
Reference in a new issue