diff --git a/logic/user_mgmt.go b/logic/user_mgmt.go index c85a6b5b..23b4e93c 100644 --- a/logic/user_mgmt.go +++ b/logic/user_mgmt.go @@ -50,7 +50,7 @@ var MigrateUserRoleAndGroups = func(u models.User) { } -var MigrateGroups = func() {} +var MigrateToUUIDs = func() {} var UpdateUserGwAccess = func(currentUser, changeUser models.User) {} diff --git a/logic/users.go b/logic/users.go index 4b9b5171..cbd453c7 100644 --- a/logic/users.go +++ b/logic/users.go @@ -199,6 +199,7 @@ func ListUserInvites() ([]models.UserInvite, error) { func DeleteUserInvite(email string) error { return database.DeleteRecord(database.USER_INVITES_TABLE_NAME, email) } + func ValidateAndApproveUserInvite(email, code string) error { in, err := GetUserInvite(email) if err != nil { diff --git a/migrate/migrate.go b/migrate/migrate.go index 83c240f2..a6c26cf4 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -29,7 +29,7 @@ func Run() { assignSuperAdmin() createDefaultTagsAndPolicies() removeOldUserGrps() - syncGroups() + migrateToUUIDs() syncUsers() updateHosts() updateNodes() @@ -394,8 +394,8 @@ func MigrateEmqx() { } -func syncGroups() { - logic.MigrateGroups() +func migrateToUUIDs() { + logic.MigrateToUUIDs() } func syncUsers() { diff --git a/pro/auth/auth.go b/pro/auth/auth.go index 70c9de13..33c271ff 100644 --- a/pro/auth/auth.go +++ b/pro/auth/auth.go @@ -1,9 +1,11 @@ package auth import ( + "encoding/json" "errors" "fmt" "net/http" + "strconv" "strings" "time" @@ -34,12 +36,38 @@ const ( // OAuthUser - generic OAuth strategy user type OAuthUser struct { - ID string `json:"id" bson:"id"` - Name string `json:"name" bson:"name"` - Email string `json:"email" bson:"email"` - Login string `json:"login" bson:"login"` - UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"` - AccessToken string `json:"accesstoken" bson:"accesstoken"` + ID StringOrInt `json:"id" bson:"id"` + Name string `json:"name" bson:"name"` + Email string `json:"email" bson:"email"` + Login string `json:"login" bson:"login"` + UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"` + AccessToken string `json:"accesstoken" bson:"accesstoken"` +} + +// TODO: this is a very poor solution. +// We should not return the same OAuthUser for different +// IdPs. They should have the user that their APIs return. +// But that's a very big change. So, making do with this +// for now. + +type StringOrInt string + +func (s *StringOrInt) UnmarshalJSON(data []byte) error { + // Try to unmarshal as string directly + var strVal string + if err := json.Unmarshal(data, &strVal); err == nil { + *s = StringOrInt(strVal) + return nil + } + + // Try to unmarshal as int and convert to string + var intVal int + if err := json.Unmarshal(data, &intVal); err == nil { + *s = StringOrInt(strconv.Itoa(intVal)) + return nil + } + + return fmt.Errorf("cannot unmarshal %s into StringOrInt", string(data)) } var ( diff --git a/pro/auth/azure-ad.go b/pro/auth/azure-ad.go index f6ce5d63..e5d9d4a8 100644 --- a/pro/auth/azure-ad.go +++ b/pro/auth/azure-ad.go @@ -111,7 +111,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.ID + user.ExternalIdentityProviderID = string(content.ID) if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -125,7 +125,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { } err = logic.InsertPendingUser(&models.User{ UserName: content.Email, - ExternalIdentityProviderID: content.ID, + ExternalIdentityProviderID: string(content.ID), AuthType: models.OAuth, }) if err != nil { @@ -243,7 +243,6 @@ func getAzureUserInfo(state string, code string) (*OAuthUser, error) { } if userInfo.Email == "" && userInfo.UserPrincipalName != "" { userInfo.Email = userInfo.UserPrincipalName - } if userInfo.Email == "" { err = errors.New("failed to fetch user email from SSO state") diff --git a/pro/auth/github.go b/pro/auth/github.go index a7d468d5..1bd8cc63 100644 --- a/pro/auth/github.go +++ b/pro/auth/github.go @@ -111,7 +111,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.ID + user.ExternalIdentityProviderID = string(content.ID) if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -125,7 +125,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { } err = logic.InsertPendingUser(&models.User{ UserName: content.Email, - ExternalIdentityProviderID: content.ID, + ExternalIdentityProviderID: string(content.ID), AuthType: models.OAuth, }) if err != nil { diff --git a/pro/auth/google.go b/pro/auth/google.go index 767645f9..e127edee 100644 --- a/pro/auth/google.go +++ b/pro/auth/google.go @@ -106,7 +106,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { } err = logic.InsertPendingUser(&models.User{ UserName: content.Email, - ExternalIdentityProviderID: content.ID, + ExternalIdentityProviderID: string(content.ID), AuthType: models.OAuth, }) if err != nil { diff --git a/pro/auth/headless_callback.go b/pro/auth/headless_callback.go index 2e13ddc9..c039a54a 100644 --- a/pro/auth/headless_callback.go +++ b/pro/auth/headless_callback.go @@ -65,7 +65,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) { if database.IsEmptyRecord(err) { // user must not exist, so try to make one err = logic.InsertPendingUser(&models.User{ UserName: userClaims.getUserName(), - ExternalIdentityProviderID: userClaims.ID, + ExternalIdentityProviderID: string(userClaims.ID), AuthType: models.OAuth, }) if err != nil { diff --git a/pro/auth/oidc.go b/pro/auth/oidc.go index 30fdd08f..d88cb4eb 100644 --- a/pro/auth/oidc.go +++ b/pro/auth/oidc.go @@ -102,7 +102,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - user.ExternalIdentityProviderID = content.ID + user.ExternalIdentityProviderID = string(content.ID) if err = logic.CreateUser(&user); err != nil { handleSomethingWentWrong(w) return @@ -116,7 +116,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { } err = logic.InsertPendingUser(&models.User{ UserName: content.Email, - ExternalIdentityProviderID: content.ID, + ExternalIdentityProviderID: string(content.ID), AuthType: models.OAuth, }) if err != nil { @@ -232,7 +232,7 @@ func getOIDCUserInfo(state string, code string) (u *OAuthUser, e error) { e = fmt.Errorf("error when claiming OIDCUser: \"%s\"", err.Error()) } - u.ID = idToken.Subject + u.ID = StringOrInt(idToken.Subject) return } diff --git a/pro/initialize.go b/pro/initialize.go index 7ecb2c87..cb3437e6 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -136,7 +136,7 @@ func InitPro() { logic.UpdateUserGwAccess = proLogic.UpdateUserGwAccess logic.CreateDefaultUserPolicies = proLogic.CreateDefaultUserPolicies logic.MigrateUserRoleAndGroups = proLogic.MigrateUserRoleAndGroups - logic.MigrateGroups = proLogic.MigrateGroups + logic.MigrateToUUIDs = proLogic.MigrateToUUIDs logic.IntialiseGroups = proLogic.UserGroupsInit logic.AddGlobalNetRolesToAdmins = proLogic.AddGlobalNetRolesToAdmins logic.GetUserGroupsInNetwork = proLogic.GetUserGroupsInNetwork diff --git a/pro/logic/migrate.go b/pro/logic/migrate.go index 5fac1ead..b265c0c0 100644 --- a/pro/logic/migrate.go +++ b/pro/logic/migrate.go @@ -9,13 +9,51 @@ import ( "github.com/gravitl/netmaker/models" ) -func MigrateGroups() { +func MigrateToUUIDs() { + roles, err := ListNetworkRoles() + if err != nil { + return + } + + rolesMapping := make(map[models.UserRoleID]models.UserRoleID) + + for _, role := range roles { + if role.Default { + continue + } + + _, err := uuid.Parse(string(role.ID)) + if err == nil { + // role id is already an uuid, so no need to update + continue + } + + oldRoleID := role.ID + role.ID = models.UserRoleID(uuid.NewString()) + rolesMapping[oldRoleID] = role.ID + + roleBytes, err := json.Marshal(role) + if err != nil { + continue + } + + err = database.Insert(role.ID.String(), string(roleBytes), database.USER_PERMISSIONS_TABLE_NAME) + if err != nil { + continue + } + + err = database.DeleteRecord(database.USER_PERMISSIONS_TABLE_NAME, oldRoleID.String()) + if err != nil { + continue + } + } + groups, err := ListUserGroups() if err != nil { return } - groupMapping := make(map[models.UserGroupID]models.UserGroupID) + groupsMapping := make(map[models.UserGroupID]models.UserGroupID) for _, group := range groups { if group.Default { @@ -30,7 +68,22 @@ func MigrateGroups() { oldGroupID := group.ID group.ID = models.UserGroupID(uuid.NewString()) - groupMapping[oldGroupID] = group.ID + groupsMapping[oldGroupID] = group.ID + + var groupPermissions = make(map[models.NetworkID]map[models.UserRoleID]struct{}) + for networkID, networkRoles := range group.NetworkRoles { + groupPermissions[networkID] = make(map[models.UserRoleID]struct{}) + for roleID := range networkRoles { + newRoleID, ok := rolesMapping[roleID] + if !ok { + groupPermissions[networkID][roleID] = struct{}{} + } else { + groupPermissions[networkID][newRoleID] = struct{}{} + } + } + } + + group.NetworkRoles = groupPermissions groupBytes, err := json.Marshal(group) if err != nil { @@ -48,6 +101,11 @@ func MigrateGroups() { } } + // if no changes were made, there are no references to be updated. + if len(rolesMapping) == 0 && len(groupsMapping) == 0 { + return + } + users, err := logic.GetUsersDB() if err != nil { return @@ -56,7 +114,7 @@ func MigrateGroups() { for _, user := range users { userGroups := make(map[models.UserGroupID]struct{}) for groupID := range user.UserGroups { - newGroupID, ok := groupMapping[groupID] + newGroupID, ok := groupsMapping[groupID] if !ok { userGroups[groupID] = struct{}{} } else { @@ -65,7 +123,81 @@ func MigrateGroups() { } user.UserGroups = userGroups - logic.UpsertUser(user) + err = logic.UpsertUser(user) + if err != nil { + continue + } + } + + for _, acl := range logic.ListAcls() { + srcList := make([]models.AclPolicyTag, len(acl.Src)) + for i, src := range acl.Src { + if src.ID == models.UserGroupAclID { + newGroupID, ok := groupsMapping[models.UserGroupID(src.Value)] + if ok { + src.Value = newGroupID.String() + } + } + + srcList[i] = src + } + + dstList := make([]models.AclPolicyTag, len(acl.Dst)) + for i, dst := range acl.Dst { + if dst.ID == models.UserGroupAclID { + newGroupID, ok := groupsMapping[models.UserGroupID(dst.Value)] + if ok { + dst.Value = newGroupID.String() + } + } + + dstList[i] = dst + } + + err = logic.UpsertAcl(acl) + if err != nil { + continue + } + } + + invites, err := logic.ListUserInvites() + if err != nil { + return + } + + for _, invite := range invites { + userGroups := make(map[models.UserGroupID]struct{}) + for groupID := range invite.UserGroups { + newGroupID, ok := groupsMapping[groupID] + if !ok { + invite.UserGroups[groupID] = struct{}{} + } else { + invite.UserGroups[newGroupID] = struct{}{} + } + } + + invite.UserGroups = userGroups + + userPermissions := make(map[models.NetworkID]map[models.UserRoleID]struct{}) + + for networkID, networkRoles := range invite.NetworkRoles { + userPermissions[networkID] = make(map[models.UserRoleID]struct{}) + for roleID := range networkRoles { + newRoleID, ok := rolesMapping[roleID] + if !ok { + userPermissions[networkID][roleID] = struct{}{} + } else { + userPermissions[networkID][newRoleID] = struct{}{} + } + } + } + + invite.NetworkRoles = userPermissions + + err = logic.InsertUserInvite(invite) + if err != nil { + continue + } } } diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index 496b54e2..976b43f1 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -18,6 +18,8 @@ import ( var ( globalNetworksAdminGroupID = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkAdmin)) globalNetworksUserGroupID = models.UserGroupID(fmt.Sprintf("global-%s-grp", models.NetworkUser)) + globalNetworksAdminRoleID = models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin)) + globalNetworksUserRoleID = models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)) ) var ServiceUserPermissionTemplate = models.UserRolePermissionTemplate{ @@ -34,7 +36,7 @@ var PlatformUserUserPermissionTemplate = models.UserRolePermissionTemplate{ } var NetworkAdminAllPermissionTemplate = models.UserRolePermissionTemplate{ - ID: models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin)), + ID: globalNetworksAdminRoleID, Name: "Network Admins", MetaData: "can manage configuration of all networks", Default: true, @@ -43,7 +45,7 @@ var NetworkAdminAllPermissionTemplate = models.UserRolePermissionTemplate{ } var NetworkUserAllPermissionTemplate = models.UserRolePermissionTemplate{ - ID: models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)), + ID: globalNetworksUserRoleID, Name: "Network Users", MetaData: "Can connect to nodes in your networks via Netmaker Desktop App.", Default: true, @@ -123,7 +125,7 @@ func UserGroupsInit() { MetaData: "can manage configuration of all networks", NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ models.AllNetworks: { - models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkAdmin)): {}, + globalNetworksAdminRoleID: {}, }, }, } @@ -133,7 +135,7 @@ func UserGroupsInit() { Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ models.AllNetworks: { - models.UserRoleID(fmt.Sprintf("global-%s", models.NetworkUser)): {}, + globalNetworksUserRoleID: {}, }, }, MetaData: "Provides read-only dashboard access to platform users and allows connection to network nodes via the Netmaker Desktop App.", @@ -149,7 +151,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { return } var NetworkAdminPermissionTemplate = models.UserRolePermissionTemplate{ - ID: models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkAdmin)), + ID: GetDefaultNetworkAdminRoleID(netID), Name: fmt.Sprintf("%s Admin", netID), MetaData: fmt.Sprintf("can manage your network `%s` configuration.", netID), Default: true, @@ -159,7 +161,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { } var NetworkUserPermissionTemplate = models.UserRolePermissionTemplate{ - ID: models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkUser)), + ID: GetDefaultNetworkUserRoleID(netID), Name: fmt.Sprintf("%s User", netID), MetaData: fmt.Sprintf("Can connect to nodes in your network `%s` via Netmaker Desktop App.", netID), Default: true, @@ -226,7 +228,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ netID: { - models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkAdmin)): {}, + GetDefaultNetworkAdminRoleID(netID): {}, }, }, MetaData: fmt.Sprintf("can manage your network `%s` configuration including adding and removing devices.", netID), @@ -237,7 +239,7 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) { Default: true, NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{ netID: { - models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkUser)): {}, + GetDefaultNetworkUserRoleID(netID): {}, }, }, MetaData: fmt.Sprintf("Can connect to nodes in your network `%s` via Netmaker Desktop App. Platform users will have read-only access to the the dashboard.", netID), @@ -402,14 +404,32 @@ func ValidateUpdateRoleReq(userRole *models.UserRolePermissionTemplate) error { // CreateRole - inserts new role into DB func CreateRole(r models.UserRolePermissionTemplate) error { - // check if role already exists - if r.ID.String() == "" { - return errors.New("role id cannot be empty") + // default roles are currently created directly in the db. + // this check is only to prevent future errors. + if r.Default && r.ID == "" { + return errors.New("role id cannot be empty for default role") } - _, err := database.FetchRecord(database.USER_PERMISSIONS_TABLE_NAME, r.ID.String()) - if err == nil { - return errors.New("role already exists") + + if !r.Default { + r.ID = models.UserRoleID(uuid.NewString()) } + + // check if the role already exists + if r.Name == "" { + return errors.New("role name cannot be empty") + } + + roles, err := ListNetworkRoles() + if err != nil { + return err + } + + for _, role := range roles { + if role.Name == r.Name { + return errors.New("role already exists") + } + } + d, err := json.Marshal(r) if err != nil { return err @@ -585,6 +605,14 @@ func GetDefaultNetworkUserGroupID(networkID models.NetworkID) models.UserGroupID return models.UserGroupID(fmt.Sprintf("%s-%s-grp", networkID, models.NetworkUser)) } +func GetDefaultNetworkAdminRoleID(networkID models.NetworkID) models.UserRoleID { + return models.UserRoleID(fmt.Sprintf("%s-%s", networkID, models.NetworkAdmin)) +} + +func GetDefaultNetworkUserRoleID(networkID models.NetworkID) models.UserRoleID { + return models.UserRoleID(fmt.Sprintf("%s-%s", networkID, models.NetworkUser)) +} + // ListUserGroups - lists user groups func ListUserGroups() ([]models.UserGroup, error) { data, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)