mirror of
https://github.com/usememos/memos.git
synced 2025-12-16 21:59:25 +08:00
fix(security): add missing authorization checks to various services (#5217)
This commit is contained in:
parent
df93120f60
commit
769dcd0cf9
6 changed files with 138 additions and 6 deletions
|
|
@ -38,8 +38,17 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
|
|||
response := &v1pb.ListIdentityProvidersResponse{
|
||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||
}
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
for _, identityProvider := range identityProviders {
|
||||
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
|
@ -58,10 +67,27 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge
|
|||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
|
||||
// Default to lowest-privilege role, update later based on real role
|
||||
currentUserRole := store.RoleUser
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err == nil && currentUser != nil {
|
||||
currentUserRole = currentUser.Role
|
||||
}
|
||||
|
||||
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
|
@ -95,6 +121,14 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
|
|||
}
|
||||
|
||||
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
||||
currentUser, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
|
|
@ -183,3 +217,13 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
|
||||
if userRole != store.RoleHost {
|
||||
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
|
||||
identityProvider.Config.GetOauth2Config().ClientSecret = ""
|
||||
}
|
||||
}
|
||||
|
||||
return identityProvider
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,13 @@ import (
|
|||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
|
|
@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||
MemoID: &memo.ID,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -14,6 +14,13 @@ import (
|
|||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
|
|
@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
|||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
|
|
|
|||
|
|
@ -55,11 +55,35 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
|||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
||||
user, err := s.GetCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||
}
|
||||
|
||||
reactionID, err := ExtractReactionIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
|
||||
}
|
||||
|
||||
// Get reaction and check ownership
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ID: &reactionID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
if len(reactions) == 0 {
|
||||
return nil, status.Errorf(codes.NotFound, "reaction not found")
|
||||
}
|
||||
|
||||
reaction := reactions[0]
|
||||
if reaction.CreatorID != user.ID && !isSuperUser(user) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
||||
ID: reactionID,
|
||||
}); err != nil {
|
||||
|
|
|
|||
|
|
@ -233,6 +233,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
|||
Name: created.Name,
|
||||
}
|
||||
|
||||
// Test unauthenticated, should not contain client secret
|
||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
|
@ -241,7 +242,18 @@ func TestGetIdentityProvider(t *testing.T) {
|
|||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
|
||||
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
|
||||
|
||||
// Test as host user, should contain client secret
|
||||
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respHostUser)
|
||||
require.Equal(t, created.Name, respHostUser.Name)
|
||||
require.Equal(t, "Test Provider", respHostUser.Title)
|
||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
|
||||
require.NotNil(t, respHostUser.Config.GetOauth2Config())
|
||||
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
|
||||
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
|
||||
})
|
||||
|
||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||
|
|
@ -353,6 +365,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "identity-providers/1",
|
||||
|
|
@ -360,7 +379,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update_mask is required")
|
||||
})
|
||||
|
|
@ -369,6 +388,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.UpdateIdentityProviderRequest{
|
||||
IdentityProvider: &v1pb.IdentityProvider{
|
||||
Name: "invalid-name",
|
||||
|
|
@ -379,7 +405,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
||||
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
|
|
@ -445,11 +471,18 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
|||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create host user
|
||||
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set user context
|
||||
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||
|
||||
req := &v1pb.DeleteIdentityProviderRequest{
|
||||
Name: "invalid-name",
|
||||
}
|
||||
|
||||
_, err := ts.Service.DeleteIdentityProvider(ctx, req)
|
||||
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||
})
|
||||
|
|
|
|||
|
|
@ -169,6 +169,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
|
|||
// Unauthenticated or non-HOST users can only create normal users
|
||||
roleToAssign = store.RoleUser
|
||||
}
|
||||
|
||||
// Only allow user registration if it is enabled in the settings, or if the user is a superuser
|
||||
if currentUser == nil || !isSuperUser(currentUser) {
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
|
||||
}
|
||||
if workspaceGeneralSetting.DisallowUserRegistration {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue