diff --git a/controllers/acls.go b/controllers/acls.go index 1321a463..3bdfa6cf 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -7,7 +7,6 @@ import ( "net/url" "time" - "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -74,15 +73,14 @@ func createAcl(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - // check if acl network exists - _, err = logic.GetNetwork(req.NetworkID.String()) + err = logic.ValidateCreateAclReq(req) if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.NetworkID.String()), "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } // check if acl exists acl := req - acl.ID = uuid.New() + acl.GetID(req.NetworkID, req.Name) acl.CreatedBy = user.UserName acl.CreatedAt = time.Now().UTC() acl.Default = false @@ -107,7 +105,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) { // @Success 200 {array} models.SuccessResponse // @Failure 500 {object} models.ErrorResponse func updateAcl(w http.ResponseWriter, r *http.Request) { - var updateAcl models.Acl + var updateAcl models.UpdateAclRequest err := json.NewDecoder(r.Body).Decode(&updateAcl) if err != nil { logger.Log(0, "error decoding request body: ", @@ -116,21 +114,37 @@ func updateAcl(w http.ResponseWriter, r *http.Request) { return } - acl, err := logic.GetAcl(updateAcl.ID.String()) + acl, err := logic.GetAcl(updateAcl.Acl.ID) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - if !logic.IsAclPolicyValid(updateAcl) { + if !logic.IsAclPolicyValid(updateAcl.Acl) { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest")) return } - err = logic.UpdateAcl(updateAcl, acl) + if updateAcl.Acl.NetworkID != acl.NetworkID { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy, network id mismatch"), "badrequest")) + return + } + if updateAcl.NewName != "" { + //check if policy exists with same name + id := models.FormatAclID(updateAcl.Acl.NetworkID, updateAcl.NewName) + _, err := logic.GetAcl(id) + if err != nil { + logic.ReturnErrorResponse(w, r, + logic.FormatError(errors.New("policy already exists with name "+updateAcl.NewName), "badrequest")) + return + } + updateAcl.Acl.ID = id + updateAcl.Acl.Name = updateAcl.NewName + } + err = logic.UpdateAcl(updateAcl.Acl, acl) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - logic.ReturnSuccessResponse(w, r, "updated acl "+updateAcl.Name) + logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name) } // @Summary Delete Acl @@ -145,7 +159,7 @@ func deleteAcl(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("acl id is required"), "badrequest")) return } - acl, err := logic.GetAcl(aclID) + acl, err := logic.GetAcl(models.AclID(aclID)) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return diff --git a/logic/acls.go b/logic/acls.go index 382bd43f..54ca4dae 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -3,10 +3,10 @@ package logic import ( "encoding/json" "errors" + "fmt" "sort" "time" - "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/models" ) @@ -14,9 +14,9 @@ import ( // CreateDefaultAclNetworkPolicies - create default acl network policies func CreateDefaultAclNetworkPolicies(netID models.NetworkID) { defaultDeviceAcl := models.Acl{ - ID: uuid.New(), - Default: true, + ID: models.AclID(fmt.Sprintf("%s.%s", netID, "all-nodes")), Name: "all-nodes", + Default: true, NetworkID: netID, RuleType: models.DevicePolicy, Src: []models.AclPolicyTag{ @@ -36,7 +36,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) { } InsertAcl(defaultDeviceAcl) defaultUserAcl := models.Acl{ - ID: uuid.New(), + ID: models.AclID(fmt.Sprintf("%s.%s", netID, "all-users")), Default: true, Name: "all-users", NetworkID: netID, @@ -73,6 +73,19 @@ func DeleteDefaultNetworkPolicies(netId models.NetworkID) { } } +// ValidateCreateAclReq - validates create req for acl +func ValidateCreateAclReq(req models.Acl) error { + // check if acl network exists + _, err := GetNetwork(req.NetworkID.String()) + if err != nil { + return errors.New("failed to get network details for " + req.NetworkID.String()) + } + if req.Name == "" { + return errors.New("name is required") + } + return nil +} + // InsertAcl - creates acl policy func InsertAcl(a models.Acl) error { d, err := json.Marshal(a) @@ -83,9 +96,9 @@ func InsertAcl(a models.Acl) error { } // GetAcl - gets acl info by id -func GetAcl(aID string) (models.Acl, error) { +func GetAcl(aID models.AclID) (models.Acl, error) { a := models.Acl{} - d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID) + d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID.String()) if err != nil { return a, err } @@ -180,13 +193,16 @@ func IsAclPolicyValid(acl models.Acl) bool { // UpdateAcl - updates allowed fields on acls and commits to DB func UpdateAcl(newAcl, acl models.Acl) error { - if newAcl.Name != "" { - acl.Name = newAcl.Name - } + + acl.Name = newAcl.Name acl.Src = newAcl.Src acl.Dst = newAcl.Dst acl.AllowedDirection = newAcl.AllowedDirection acl.Enabled = newAcl.Enabled + if acl.ID != newAcl.ID { + database.DeleteRecord(acl.ID.String(), database.ACLS_TABLE_NAME) + acl.ID = newAcl.ID + } d, err := json.Marshal(acl) if err != nil { return err @@ -212,7 +228,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo // ListUserPolicies - lists all acl policies enforced on an user func ListUserPolicies(u models.User) []models.Acl { - data, err := database.FetchRecords(database.TAG_TABLE_NAME) + data, err := database.FetchRecords(database.ACLS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return []models.Acl{} } @@ -245,7 +261,7 @@ func ListUserPolicies(u models.User) []models.Acl { // ListUserPoliciesByNetwork - lists all acl user policies in a network func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl { - data, err := database.FetchRecords(database.TAG_TABLE_NAME) + data, err := database.FetchRecords(database.ACLS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return []models.Acl{} } @@ -265,7 +281,7 @@ func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl { // listDevicePolicies - lists all device policies in a network func listDevicePolicies(netID models.NetworkID) []models.Acl { - data, err := database.FetchRecords(database.TAG_TABLE_NAME) + data, err := database.FetchRecords(database.ACLS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return []models.Acl{} } @@ -285,7 +301,7 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl { // ListAcls - lists all acl policies func ListAcls(netID models.NetworkID) ([]models.Acl, error) { - data, err := database.FetchRecords(database.TAG_TABLE_NAME) + data, err := database.FetchRecords(database.ACLS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return []models.Acl{}, err } diff --git a/migrate/migrate.go b/migrate/migrate.go index a2d9b65d..850e438b 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -320,6 +320,7 @@ func syncUsers() { if err == nil { for _, netI := range networks { logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID)) + logic.CreateDefaultAclNetworkPolicies(models.NetworkID(netI.NetID)) networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID) for _, networkNodeI := range networkNodes { if networkNodeI.IsIngressGateway { diff --git a/models/acl.go b/models/acl.go index 17386149..acd1decb 100644 --- a/models/acl.go +++ b/models/acl.go @@ -1,11 +1,24 @@ package models import ( + "fmt" "time" - - "github.com/google/uuid" ) +type AclID string + +func (aID AclID) String() string { + return string(aID) +} + +func (a *Acl) GetID(netID NetworkID, name string) { + a.ID = AclID(fmt.Sprintf("%s.%s", netID.String(), name)) +} + +func FormatAclID(netID NetworkID, name string) AclID { + return AclID(fmt.Sprintf("%s.%s", netID.String(), name)) +} + // AllowedTrafficDirection - allowed direction of traffic type AllowedTrafficDirection int @@ -42,8 +55,13 @@ func (g AclGroupType) String() string { return string(g) } +type UpdateAclRequest struct { + Acl Acl + NewName string `json:"new_name"` +} + type Acl struct { - ID uuid.UUID `json:"id"` + ID AclID `json:"id"` Default bool `json:"default"` Name string `json:"name"` NetworkID NetworkID `json:"network_id"`