diff --git a/server/router/api/v1/idp_service.go b/server/router/api/v1/idp_service.go index 384119b84..51b44151f 100644 --- a/server/router/api/v1/idp_service.go +++ b/server/router/api/v1/idp_service.go @@ -38,8 +38,17 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId response := &v1pb.ListIdentityProvidersResponse{ IdentityProviders: []*v1pb.IdentityProvider{}, } + + // Default to lowest-privilege role, update later based on real role + currentUserRole := store.RoleUser + currentUser, err := s.GetCurrentUser(ctx) + if err == nil && currentUser != nil { + currentUserRole = currentUser.Role + } + for _, identityProvider := range identityProviders { - response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider)) + identityProviderConverted := convertIdentityProviderFromStore(identityProvider) + response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole)) } return response, nil } @@ -58,10 +67,27 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge if identityProvider == nil { return nil, status.Errorf(codes.NotFound, "identity provider not found") } - return convertIdentityProviderFromStore(identityProvider), nil + + // Default to lowest-privilege role, update later based on real role + currentUserRole := store.RoleUser + currentUser, err := s.GetCurrentUser(ctx) + if err == nil && currentUser != nil { + currentUserRole = currentUser.Role + } + + identityProviderConverted := convertIdentityProviderFromStore(identityProvider) + return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil } func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) { + currentUser, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) + } + if currentUser == nil || currentUser.Role != store.RoleHost { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 { return nil, status.Errorf(codes.InvalidArgument, "update_mask is required") } @@ -95,6 +121,14 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb } func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) { + currentUser, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) + } + if currentUser == nil || currentUser.Role != store.RoleHost { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + id, err := ExtractIdentityProviderIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) @@ -183,3 +217,13 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv } return nil } + +func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider { + if userRole != store.RoleHost { + if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 { + identityProvider.Config.GetOauth2Config().ClientSecret = "" + } + } + + return identityProvider +} diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index e7c7f18ac..4084c9a8a 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -14,6 +14,13 @@ import ( ) func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) { + user, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } memoUID, err := ExtractMemoUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) @@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if memo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ MemoID: &memo.ID, }) diff --git a/server/router/api/v1/memo_relation_service.go b/server/router/api/v1/memo_relation_service.go index 247f85ac8..77cff1a38 100644 --- a/server/router/api/v1/memo_relation_service.go +++ b/server/router/api/v1/memo_relation_service.go @@ -14,6 +14,13 @@ import ( ) func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) { + user, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } memoUID, err := ExtractMemoUIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) @@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if memo.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } referenceType := store.MemoRelationReference // Delete all reference relations first. if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ diff --git a/server/router/api/v1/reaction_service.go b/server/router/api/v1/reaction_service.go index 7dd007d8f..f5ec6d96d 100644 --- a/server/router/api/v1/reaction_service.go +++ b/server/router/api/v1/reaction_service.go @@ -55,11 +55,35 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups } func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) { + user, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + reactionID, err := ExtractReactionIDFromName(request.Name) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err) } + // Get reaction and check ownership + reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ + ID: &reactionID, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list reactions") + } + if len(reactions) == 0 { + return nil, status.Errorf(codes.NotFound, "reaction not found") + } + + reaction := reactions[0] + if reaction.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{ ID: reactionID, }); err != nil { diff --git a/server/router/api/v1/test/idp_service_test.go b/server/router/api/v1/test/idp_service_test.go index 4b7cec763..d60d42de0 100644 --- a/server/router/api/v1/test/idp_service_test.go +++ b/server/router/api/v1/test/idp_service_test.go @@ -233,6 +233,7 @@ func TestGetIdentityProvider(t *testing.T) { Name: created.Name, } + // Test unauthenticated, should not contain client secret resp, err := ts.Service.GetIdentityProvider(ctx, getReq) require.NoError(t, err) require.NotNil(t, resp) @@ -241,7 +242,18 @@ func TestGetIdentityProvider(t *testing.T) { require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type) require.NotNil(t, resp.Config.GetOauth2Config()) require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId) - require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret) + require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret) + + // Test as host user, should contain client secret + respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq) + require.NoError(t, err) + require.NotNil(t, respHostUser) + require.Equal(t, created.Name, respHostUser.Name) + require.Equal(t, "Test Provider", respHostUser.Title) + require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type) + require.NotNil(t, respHostUser.Config.GetOauth2Config()) + require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId) + require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret) }) t.Run("GetIdentityProvider not found", func(t *testing.T) { @@ -353,6 +365,13 @@ func TestUpdateIdentityProvider(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.ID) + req := &v1pb.UpdateIdentityProviderRequest{ IdentityProvider: &v1pb.IdentityProvider{ Name: "identity-providers/1", @@ -360,7 +379,7 @@ func TestUpdateIdentityProvider(t *testing.T) { }, } - _, err := ts.Service.UpdateIdentityProvider(ctx, req) + _, err = ts.Service.UpdateIdentityProvider(userCtx, req) require.Error(t, err) require.Contains(t, err.Error(), "update_mask is required") }) @@ -369,6 +388,13 @@ func TestUpdateIdentityProvider(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.ID) + req := &v1pb.UpdateIdentityProviderRequest{ IdentityProvider: &v1pb.IdentityProvider{ Name: "invalid-name", @@ -379,7 +405,7 @@ func TestUpdateIdentityProvider(t *testing.T) { }, } - _, err := ts.Service.UpdateIdentityProvider(ctx, req) + _, err = ts.Service.UpdateIdentityProvider(userCtx, req) require.Error(t, err) require.Contains(t, err.Error(), "invalid identity provider name") }) @@ -445,11 +471,18 @@ func TestDeleteIdentityProvider(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.ID) + req := &v1pb.DeleteIdentityProviderRequest{ Name: "invalid-name", } - _, err := ts.Service.DeleteIdentityProvider(ctx, req) + _, err = ts.Service.DeleteIdentityProvider(userCtx, req) require.Error(t, err) require.Contains(t, err.Error(), "invalid identity provider name") }) diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 2b746e8ce..e5de08db3 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -169,6 +169,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR // Unauthenticated or non-HOST users can only create normal users roleToAssign = store.RoleUser } + + // Only allow user registration if it is enabled in the settings, or if the user is a superuser + if currentUser == nil || !isSuperUser(currentUser) { + workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) + } + if workspaceGeneralSetting.DisallowUserRegistration { + return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed") + } + } } if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {