mirror of
https://github.com/usememos/memos.git
synced 2025-12-18 14:50:13 +08:00
fix(security): add missing authorization checks to various services (#5217)
This commit is contained in:
parent
df93120f60
commit
769dcd0cf9
6 changed files with 138 additions and 6 deletions
|
|
@ -38,8 +38,17 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
|
||||||
response := &v1pb.ListIdentityProvidersResponse{
|
response := &v1pb.ListIdentityProvidersResponse{
|
||||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default to lowest-privilege role, update later based on real role
|
||||||
|
currentUserRole := store.RoleUser
|
||||||
|
currentUser, err := s.GetCurrentUser(ctx)
|
||||||
|
if err == nil && currentUser != nil {
|
||||||
|
currentUserRole = currentUser.Role
|
||||||
|
}
|
||||||
|
|
||||||
for _, identityProvider := range identityProviders {
|
for _, identityProvider := range identityProviders {
|
||||||
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
|
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||||
|
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
|
||||||
}
|
}
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
@ -58,10 +67,27 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge
|
||||||
if identityProvider == nil {
|
if identityProvider == nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||||
}
|
}
|
||||||
return convertIdentityProviderFromStore(identityProvider), nil
|
|
||||||
|
// Default to lowest-privilege role, update later based on real role
|
||||||
|
currentUserRole := store.RoleUser
|
||||||
|
currentUser, err := s.GetCurrentUser(ctx)
|
||||||
|
if err == nil && currentUser != nil {
|
||||||
|
currentUserRole = currentUser.Role
|
||||||
|
}
|
||||||
|
|
||||||
|
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
|
||||||
|
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||||
|
currentUser, err := s.GetCurrentUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||||
|
}
|
||||||
|
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||||
}
|
}
|
||||||
|
|
@ -95,6 +121,14 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
||||||
|
currentUser, err := s.GetCurrentUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||||
|
}
|
||||||
|
if currentUser == nil || currentUser.Role != store.RoleHost {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||||
|
|
@ -183,3 +217,13 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
|
||||||
|
if userRole != store.RoleHost {
|
||||||
|
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
|
||||||
|
identityProvider.Config.GetOauth2Config().ClientSecret = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return identityProvider
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
|
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
|
||||||
|
user, err := s.GetCurrentUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||||
|
}
|
||||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||||
|
|
@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||||
}
|
}
|
||||||
|
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
|
}
|
||||||
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
|
||||||
MemoID: &memo.ID,
|
MemoID: &memo.ID,
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
||||||
|
user, err := s.GetCurrentUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||||
|
}
|
||||||
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
memoUID, err := ExtractMemoUIDFromName(request.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||||
|
|
@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||||
}
|
}
|
||||||
|
if memo.CreatorID != user.ID && !isSuperUser(user) {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
|
}
|
||||||
referenceType := store.MemoRelationReference
|
referenceType := store.MemoRelationReference
|
||||||
// Delete all reference relations first.
|
// Delete all reference relations first.
|
||||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||||
|
|
|
||||||
|
|
@ -55,11 +55,35 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
||||||
|
user, err := s.GetCurrentUser(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
reactionID, err := ExtractReactionIDFromName(request.Name)
|
reactionID, err := ExtractReactionIDFromName(request.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get reaction and check ownership
|
||||||
|
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||||
|
ID: &reactionID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||||
|
}
|
||||||
|
if len(reactions) == 0 {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "reaction not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
reaction := reactions[0]
|
||||||
|
if reaction.CreatorID != user.ID && !isSuperUser(user) {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
||||||
ID: reactionID,
|
ID: reactionID,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -233,6 +233,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
||||||
Name: created.Name,
|
Name: created.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test unauthenticated, should not contain client secret
|
||||||
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
|
|
@ -241,7 +242,18 @@ func TestGetIdentityProvider(t *testing.T) {
|
||||||
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
|
||||||
require.NotNil(t, resp.Config.GetOauth2Config())
|
require.NotNil(t, resp.Config.GetOauth2Config())
|
||||||
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
|
||||||
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
|
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
|
||||||
|
|
||||||
|
// Test as host user, should contain client secret
|
||||||
|
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, respHostUser)
|
||||||
|
require.Equal(t, created.Name, respHostUser.Name)
|
||||||
|
require.Equal(t, "Test Provider", respHostUser.Title)
|
||||||
|
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
|
||||||
|
require.NotNil(t, respHostUser.Config.GetOauth2Config())
|
||||||
|
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
|
||||||
|
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
t.Run("GetIdentityProvider not found", func(t *testing.T) {
|
||||||
|
|
@ -353,6 +365,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||||
ts := NewTestService(t)
|
ts := NewTestService(t)
|
||||||
defer ts.Cleanup()
|
defer ts.Cleanup()
|
||||||
|
|
||||||
|
// Create host user
|
||||||
|
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set user context
|
||||||
|
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||||
|
|
||||||
req := &v1pb.UpdateIdentityProviderRequest{
|
req := &v1pb.UpdateIdentityProviderRequest{
|
||||||
IdentityProvider: &v1pb.IdentityProvider{
|
IdentityProvider: &v1pb.IdentityProvider{
|
||||||
Name: "identity-providers/1",
|
Name: "identity-providers/1",
|
||||||
|
|
@ -360,7 +379,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "update_mask is required")
|
require.Contains(t, err.Error(), "update_mask is required")
|
||||||
})
|
})
|
||||||
|
|
@ -369,6 +388,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||||
ts := NewTestService(t)
|
ts := NewTestService(t)
|
||||||
defer ts.Cleanup()
|
defer ts.Cleanup()
|
||||||
|
|
||||||
|
// Create host user
|
||||||
|
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set user context
|
||||||
|
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||||
|
|
||||||
req := &v1pb.UpdateIdentityProviderRequest{
|
req := &v1pb.UpdateIdentityProviderRequest{
|
||||||
IdentityProvider: &v1pb.IdentityProvider{
|
IdentityProvider: &v1pb.IdentityProvider{
|
||||||
Name: "invalid-name",
|
Name: "invalid-name",
|
||||||
|
|
@ -379,7 +405,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
|
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||||
})
|
})
|
||||||
|
|
@ -445,11 +471,18 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
||||||
ts := NewTestService(t)
|
ts := NewTestService(t)
|
||||||
defer ts.Cleanup()
|
defer ts.Cleanup()
|
||||||
|
|
||||||
|
// Create host user
|
||||||
|
hostUser, err := ts.CreateHostUser(ctx, "admin")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set user context
|
||||||
|
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
|
||||||
|
|
||||||
req := &v1pb.DeleteIdentityProviderRequest{
|
req := &v1pb.DeleteIdentityProviderRequest{
|
||||||
Name: "invalid-name",
|
Name: "invalid-name",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := ts.Service.DeleteIdentityProvider(ctx, req)
|
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "invalid identity provider name")
|
require.Contains(t, err.Error(), "invalid identity provider name")
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -169,6 +169,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
|
||||||
// Unauthenticated or non-HOST users can only create normal users
|
// Unauthenticated or non-HOST users can only create normal users
|
||||||
roleToAssign = store.RoleUser
|
roleToAssign = store.RoleUser
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only allow user registration if it is enabled in the settings, or if the user is a superuser
|
||||||
|
if currentUser == nil || !isSuperUser(currentUser) {
|
||||||
|
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
|
||||||
|
}
|
||||||
|
if workspaceGeneralSetting.DisallowUserRegistration {
|
||||||
|
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue