remove uuid on id type

This commit is contained in:
abhishek9686 2024-09-26 18:45:54 +04:00
parent 1d1c033988
commit 940ed8b2f0
4 changed files with 76 additions and 27 deletions

View file

@ -7,7 +7,6 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic"
@ -74,15 +73,14 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
// check if acl network exists err = logic.ValidateCreateAclReq(req)
_, err = logic.GetNetwork(req.NetworkID.String())
if err != nil { if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.NetworkID.String()), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
// check if acl exists // check if acl exists
acl := req acl := req
acl.ID = uuid.New() acl.GetID(req.NetworkID, req.Name)
acl.CreatedBy = user.UserName acl.CreatedBy = user.UserName
acl.CreatedAt = time.Now().UTC() acl.CreatedAt = time.Now().UTC()
acl.Default = false acl.Default = false
@ -107,7 +105,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
// @Success 200 {array} models.SuccessResponse // @Success 200 {array} models.SuccessResponse
// @Failure 500 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse
func updateAcl(w http.ResponseWriter, r *http.Request) { func updateAcl(w http.ResponseWriter, r *http.Request) {
var updateAcl models.Acl var updateAcl models.UpdateAclRequest
err := json.NewDecoder(r.Body).Decode(&updateAcl) err := json.NewDecoder(r.Body).Decode(&updateAcl)
if err != nil { if err != nil {
logger.Log(0, "error decoding request body: ", logger.Log(0, "error decoding request body: ",
@ -116,21 +114,37 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
return return
} }
acl, err := logic.GetAcl(updateAcl.ID.String()) acl, err := logic.GetAcl(updateAcl.Acl.ID)
if err != nil { if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
if !logic.IsAclPolicyValid(updateAcl) { if !logic.IsAclPolicyValid(updateAcl.Acl) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
return return
} }
err = logic.UpdateAcl(updateAcl, acl) if updateAcl.Acl.NetworkID != acl.NetworkID {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy, network id mismatch"), "badrequest"))
return
}
if updateAcl.NewName != "" {
//check if policy exists with same name
id := models.FormatAclID(updateAcl.Acl.NetworkID, updateAcl.NewName)
_, err := logic.GetAcl(id)
if err != nil {
logic.ReturnErrorResponse(w, r,
logic.FormatError(errors.New("policy already exists with name "+updateAcl.NewName), "badrequest"))
return
}
updateAcl.Acl.ID = id
updateAcl.Acl.Name = updateAcl.NewName
}
err = logic.UpdateAcl(updateAcl.Acl, acl)
if err != nil { if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return
} }
logic.ReturnSuccessResponse(w, r, "updated acl "+updateAcl.Name) logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name)
} }
// @Summary Delete Acl // @Summary Delete Acl
@ -145,7 +159,7 @@ func deleteAcl(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("acl id is required"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("acl id is required"), "badrequest"))
return return
} }
acl, err := logic.GetAcl(aclID) acl, err := logic.GetAcl(models.AclID(aclID))
if err != nil { if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return return

View file

@ -3,10 +3,10 @@ package logic
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sort" "sort"
"time" "time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
) )
@ -14,9 +14,9 @@ import (
// CreateDefaultAclNetworkPolicies - create default acl network policies // CreateDefaultAclNetworkPolicies - create default acl network policies
func CreateDefaultAclNetworkPolicies(netID models.NetworkID) { func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
defaultDeviceAcl := models.Acl{ defaultDeviceAcl := models.Acl{
ID: uuid.New(), ID: models.AclID(fmt.Sprintf("%s.%s", netID, "all-nodes")),
Default: true,
Name: "all-nodes", Name: "all-nodes",
Default: true,
NetworkID: netID, NetworkID: netID,
RuleType: models.DevicePolicy, RuleType: models.DevicePolicy,
Src: []models.AclPolicyTag{ Src: []models.AclPolicyTag{
@ -36,7 +36,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
} }
InsertAcl(defaultDeviceAcl) InsertAcl(defaultDeviceAcl)
defaultUserAcl := models.Acl{ defaultUserAcl := models.Acl{
ID: uuid.New(), ID: models.AclID(fmt.Sprintf("%s.%s", netID, "all-users")),
Default: true, Default: true,
Name: "all-users", Name: "all-users",
NetworkID: netID, NetworkID: netID,
@ -73,6 +73,19 @@ func DeleteDefaultNetworkPolicies(netId models.NetworkID) {
} }
} }
// ValidateCreateAclReq - validates create req for acl
func ValidateCreateAclReq(req models.Acl) error {
// check if acl network exists
_, err := GetNetwork(req.NetworkID.String())
if err != nil {
return errors.New("failed to get network details for " + req.NetworkID.String())
}
if req.Name == "" {
return errors.New("name is required")
}
return nil
}
// InsertAcl - creates acl policy // InsertAcl - creates acl policy
func InsertAcl(a models.Acl) error { func InsertAcl(a models.Acl) error {
d, err := json.Marshal(a) d, err := json.Marshal(a)
@ -83,9 +96,9 @@ func InsertAcl(a models.Acl) error {
} }
// GetAcl - gets acl info by id // GetAcl - gets acl info by id
func GetAcl(aID string) (models.Acl, error) { func GetAcl(aID models.AclID) (models.Acl, error) {
a := models.Acl{} a := models.Acl{}
d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID) d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID.String())
if err != nil { if err != nil {
return a, err return a, err
} }
@ -180,13 +193,16 @@ func IsAclPolicyValid(acl models.Acl) bool {
// UpdateAcl - updates allowed fields on acls and commits to DB // UpdateAcl - updates allowed fields on acls and commits to DB
func UpdateAcl(newAcl, acl models.Acl) error { func UpdateAcl(newAcl, acl models.Acl) error {
if newAcl.Name != "" {
acl.Name = newAcl.Name acl.Name = newAcl.Name
}
acl.Src = newAcl.Src acl.Src = newAcl.Src
acl.Dst = newAcl.Dst acl.Dst = newAcl.Dst
acl.AllowedDirection = newAcl.AllowedDirection acl.AllowedDirection = newAcl.AllowedDirection
acl.Enabled = newAcl.Enabled acl.Enabled = newAcl.Enabled
if acl.ID != newAcl.ID {
database.DeleteRecord(acl.ID.String(), database.ACLS_TABLE_NAME)
acl.ID = newAcl.ID
}
d, err := json.Marshal(acl) d, err := json.Marshal(acl)
if err != nil { if err != nil {
return err return err
@ -212,7 +228,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
// ListUserPolicies - lists all acl policies enforced on an user // ListUserPolicies - lists all acl policies enforced on an user
func ListUserPolicies(u models.User) []models.Acl { func ListUserPolicies(u models.User) []models.Acl {
data, err := database.FetchRecords(database.TAG_TABLE_NAME) data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
return []models.Acl{} return []models.Acl{}
} }
@ -245,7 +261,7 @@ func ListUserPolicies(u models.User) []models.Acl {
// ListUserPoliciesByNetwork - lists all acl user policies in a network // ListUserPoliciesByNetwork - lists all acl user policies in a network
func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl { func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
data, err := database.FetchRecords(database.TAG_TABLE_NAME) data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
return []models.Acl{} return []models.Acl{}
} }
@ -265,7 +281,7 @@ func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
// listDevicePolicies - lists all device policies in a network // listDevicePolicies - lists all device policies in a network
func listDevicePolicies(netID models.NetworkID) []models.Acl { func listDevicePolicies(netID models.NetworkID) []models.Acl {
data, err := database.FetchRecords(database.TAG_TABLE_NAME) data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
return []models.Acl{} return []models.Acl{}
} }
@ -285,7 +301,7 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl {
// ListAcls - lists all acl policies // ListAcls - lists all acl policies
func ListAcls(netID models.NetworkID) ([]models.Acl, error) { func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
data, err := database.FetchRecords(database.TAG_TABLE_NAME) data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) { if err != nil && !database.IsEmptyRecord(err) {
return []models.Acl{}, err return []models.Acl{}, err
} }

View file

@ -320,6 +320,7 @@ func syncUsers() {
if err == nil { if err == nil {
for _, netI := range networks { for _, netI := range networks {
logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID)) logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID))
logic.CreateDefaultAclNetworkPolicies(models.NetworkID(netI.NetID))
networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID) networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
for _, networkNodeI := range networkNodes { for _, networkNodeI := range networkNodes {
if networkNodeI.IsIngressGateway { if networkNodeI.IsIngressGateway {

View file

@ -1,11 +1,24 @@
package models package models
import ( import (
"fmt"
"time" "time"
"github.com/google/uuid"
) )
type AclID string
func (aID AclID) String() string {
return string(aID)
}
func (a *Acl) GetID(netID NetworkID, name string) {
a.ID = AclID(fmt.Sprintf("%s.%s", netID.String(), name))
}
func FormatAclID(netID NetworkID, name string) AclID {
return AclID(fmt.Sprintf("%s.%s", netID.String(), name))
}
// AllowedTrafficDirection - allowed direction of traffic // AllowedTrafficDirection - allowed direction of traffic
type AllowedTrafficDirection int type AllowedTrafficDirection int
@ -42,8 +55,13 @@ func (g AclGroupType) String() string {
return string(g) return string(g)
} }
type UpdateAclRequest struct {
Acl Acl
NewName string `json:"new_name"`
}
type Acl struct { type Acl struct {
ID uuid.UUID `json:"id"` ID AclID `json:"id"`
Default bool `json:"default"` Default bool `json:"default"`
Name string `json:"name"` Name string `json:"name"`
NetworkID NetworkID `json:"network_id"` NetworkID NetworkID `json:"network_id"`