enforce new acl policy access check

This commit is contained in:
abhishek9686 2024-09-25 16:06:08 +04:00
parent 00b082d11c
commit fcd3325173
3 changed files with 110 additions and 23 deletions

View file

@ -2,8 +2,8 @@ package logic
import (
"encoding/json"
"errors"
"sort"
"strings"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
@ -38,22 +38,22 @@ func IsAclPolicyValid(acl models.Acl) bool {
case models.UserPolicy:
// src list should only contain users
for _, srcI := range acl.Src {
userTagLi := strings.Split(srcI, ":")
if len(userTagLi) < 2 {
if srcI.ID == "" || srcI.Value == "" {
break
}
if userTagLi[0] != models.UserAclID.String() &&
userTagLi[0] != models.UserGroupAclID.String() {
if srcI.ID != models.UserAclID &&
srcI.ID != models.UserGroupAclID {
break
}
// check if user group is valid
if userTagLi[0] == models.UserAclID.String() {
_, err := GetUser(userTagLi[1])
if srcI.ID == models.UserAclID {
_, err := GetUser(srcI.Value)
if err != nil {
break
}
} else if userTagLi[0] == models.UserGroupAclID.String() {
err := IsGroupValid(models.UserGroupID(userTagLi[1]))
} else if srcI.ID == models.UserGroupAclID {
err := IsGroupValid(models.UserGroupID(srcI.Value))
if err != nil {
break
}
@ -61,19 +61,19 @@ func IsAclPolicyValid(acl models.Acl) bool {
}
for _, dstI := range acl.Dst {
dstILi := strings.Split(dstI, ":")
if len(dstILi) < 2 {
if dstI.ID == "" || dstI.Value == "" {
break
}
if dstILi[0] == models.UserAclID.String() ||
dstILi[0] == models.UserGroupAclID.String() {
if dstI.ID == models.UserAclID ||
dstI.ID == models.UserGroupAclID {
break
}
if dstILi[0] != models.DeviceAclID.String() {
if dstI.ID != models.DeviceAclID {
break
}
// check if tag is valid
_, err := GetTag(models.TagID(dstILi[1]))
_, err := GetTag(models.TagID(dstI.Value))
if err != nil {
break
}
@ -81,20 +81,29 @@ func IsAclPolicyValid(acl models.Acl) bool {
isValid = true
case models.DevicePolicy:
for _, srcI := range acl.Src {
deviceTagLi := strings.Split(srcI, ":")
if len(deviceTagLi) < 2 {
if srcI.ID == "" || srcI.Value == "" {
break
}
if deviceTagLi[0] != models.DeviceAclID.String() {
if srcI.ID != models.DeviceAclID {
break
}
// check if tag is valid
_, err := GetTag(models.TagID(srcI.Value))
if err != nil {
break
}
}
for _, dstI := range acl.Dst {
deviceTagLi := strings.Split(dstI, ":")
if len(deviceTagLi) < 2 {
if dstI.ID == "" || dstI.Value == "" {
break
}
if deviceTagLi[0] != models.DeviceAclID.String() {
if dstI.ID != models.DeviceAclID {
break
}
// check if tag is valid
_, err := GetTag(models.TagID(dstI.Value))
if err != nil {
break
}
}
@ -124,6 +133,36 @@ func DeleteAcl(a models.Acl) error {
return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID.String())
}
func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
acls, _ := ListAcls(netID)
for _, acl := range acls {
if acl.Default && acl.RuleType == ruleType {
return acl, nil
}
}
return models.Acl{}, errors.New("default rule not found")
}
// listDevicePolicies - lists all device policies in a network
func listDevicePolicies(netID models.NetworkID) []models.Acl {
data, err := database.FetchRecords(database.TAG_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return []models.Acl{}
}
acls := []models.Acl{}
for _, dataI := range data {
acl := models.Acl{}
err := json.Unmarshal([]byte(dataI), &acl)
if err != nil {
continue
}
if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
acls = append(acls, acl)
}
}
return acls
}
// ListAcls - lists all acl policies
func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
data, err := database.FetchRecords(database.TAG_TABLE_NAME)
@ -144,6 +183,47 @@ func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
return acls, nil
}
func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
aclValueMap := make(map[string]struct{})
for _, aclTagI := range acltags {
aclValueMap[aclTagI.ID.String()] = struct{}{}
}
return aclValueMap
}
func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
// check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
if err == nil {
if defaultPolicy.Enabled {
return true
}
}
// list device policies
policies := listDevicePolicies(models.NetworkID(peer.Network))
for _, policy := range policies {
srcMap := convAclTagToValueMap(policy.Src)
dstMap := convAclTagToValueMap(policy.Dst)
for tagID := range peer.Tags {
if _, ok := dstMap[tagID.String()]; ok {
for tagID := range node.Tags {
if _, ok := srcMap[tagID.String()]; ok {
return true
}
}
}
if _, ok := srcMap[tagID.String()]; ok {
for tagID := range node.Tags {
if _, ok := dstMap[tagID.String()]; ok {
return true
}
}
}
}
}
return false
}
// SortTagEntrys - Sorts slice of Tag entries by their id
func SortAclEntrys(acls []models.Acl) {
sort.Slice(acls, func(i, j int) bool {

View file

@ -241,6 +241,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
!peer.PendingDelete &&
peer.Connected &&
nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
IsNodeAllowedToCommunicate(node, peer) &&
(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
}

View file

@ -23,6 +23,11 @@ const (
DevicePolicy AclPolicyType = "device-policy"
)
type AclPolicyTag struct {
ID AclGroupType `json:"id"`
Value string `json:"value"`
}
type AclGroupType string
const (
@ -39,11 +44,12 @@ func (g AclGroupType) String() string {
type Acl struct {
ID uuid.UUID `json:"id"`
Default bool `json:"default"`
Name string `json:"name"`
NetworkID NetworkID `json:"network_id"`
RuleType AclPolicyType `json:"policy_type"`
Src []string `json:"src_type"`
Dst []string `json:"dst_type"`
Src []AclPolicyTag `json:"src_type"`
Dst []AclPolicyTag `json:"dst_type"`
AllowedDirection AllowedTrafficDirection `json:"allowed_traffic_direction"`
Enabled bool `json:"enabled"`
CreatedBy string `json:"created_by"`