memos/server/route/api/v2/idp_service.go

170 lines
6.7 KiB
Go
Raw Normal View History

2024-04-13 10:50:25 +08:00
package v2
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
2024-04-27 22:02:15 +08:00
"google.golang.org/protobuf/types/known/emptypb"
2024-04-13 10:50:25 +08:00
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
2024-04-27 22:02:15 +08:00
func (s *APIV2Service) CreateIdentityProvider(ctx context.Context, request *apiv2pb.CreateIdentityProviderRequest) (*apiv2pb.IdentityProvider, error) {
2024-04-13 10:50:25 +08:00
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
2024-04-17 08:56:52 +08:00
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
2024-04-13 10:50:25 +08:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
2024-04-27 22:02:15 +08:00
return convertIdentityProviderFromStore(identityProvider), nil
2024-04-13 10:50:25 +08:00
}
func (s *APIV2Service) ListIdentityProviders(ctx context.Context, _ *apiv2pb.ListIdentityProvidersRequest) (*apiv2pb.ListIdentityProvidersResponse, error) {
2024-04-17 08:56:52 +08:00
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
2024-04-13 10:50:25 +08:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &apiv2pb.ListIdentityProvidersResponse{
IdentityProviders: []*apiv2pb.IdentityProvider{},
}
for _, identityProvider := range identityProviders {
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
}
return response, nil
}
2024-04-27 22:02:15 +08:00
func (s *APIV2Service) GetIdentityProvider(ctx context.Context, request *apiv2pb.GetIdentityProviderRequest) (*apiv2pb.IdentityProvider, error) {
2024-04-13 10:50:25 +08:00
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
2024-04-17 08:56:52 +08:00
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
2024-04-13 10:50:25 +08:00
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
2024-04-27 22:02:15 +08:00
return convertIdentityProviderFromStore(identityProvider), nil
2024-04-13 10:50:25 +08:00
}
2024-04-27 22:02:15 +08:00
func (s *APIV2Service) UpdateIdentityProvider(ctx context.Context, request *apiv2pb.UpdateIdentityProviderRequest) (*apiv2pb.IdentityProvider, error) {
2024-04-13 10:50:25 +08:00
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
}
}
2024-04-17 08:56:52 +08:00
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
2024-04-13 10:50:25 +08:00
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
2024-04-27 22:02:15 +08:00
return convertIdentityProviderFromStore(identityProvider), nil
2024-04-13 10:50:25 +08:00
}
2024-04-27 22:02:15 +08:00
func (s *APIV2Service) DeleteIdentityProvider(ctx context.Context, request *apiv2pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
2024-04-13 10:50:25 +08:00
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
2024-04-27 22:02:15 +08:00
return &emptypb.Empty{}, nil
2024-04-13 10:50:25 +08:00
}
// convertIdentityProviderFromStore converts a stored identity provider into
// its API representation.
//
// The resource name is IdentityProviderNamePrefix followed by the numeric ID,
// and the store's Name becomes the API Title. The enum type is mapped by its
// string name. For OAUTH2 providers the OAuth2 settings, including the client
// secret, are copied. A missing config or field mapping yields empty values
// rather than a nil-pointer panic.
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *apiv2pb.IdentityProvider {
	temp := &apiv2pb.IdentityProvider{
		Name:             fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
		Title:            identityProvider.Name,
		IdentifierFilter: identityProvider.IdentifierFilter,
		Type:             apiv2pb.IdentityProvider_Type(apiv2pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
	}
	if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
		// Generated getters return zero values on nil receivers, so absent
		// nested messages do not panic.
		oauth2Config := identityProvider.GetConfig().GetOauth2Config()
		fieldMapping := oauth2Config.GetFieldMapping()
		temp.Config = &apiv2pb.IdentityProviderConfig{
			Config: &apiv2pb.IdentityProviderConfig_Oauth2Config{
				Oauth2Config: &apiv2pb.OAuth2Config{
					ClientId:     oauth2Config.GetClientId(),
					ClientSecret: oauth2Config.GetClientSecret(),
					AuthUrl:      oauth2Config.GetAuthUrl(),
					TokenUrl:     oauth2Config.GetTokenUrl(),
					UserInfoUrl:  oauth2Config.GetUserInfoUrl(),
					Scopes:       oauth2Config.GetScopes(),
					FieldMapping: &apiv2pb.FieldMapping{
						Identifier:  fieldMapping.GetIdentifier(),
						DisplayName: fieldMapping.GetDisplayName(),
						Email:       fieldMapping.GetEmail(),
					},
				},
			},
		}
	}
	return temp
}
// convertIdentityProviderToStore maps an API identity provider onto its store
// representation. The API Title is stored as Name, the enum type is mapped by
// its string name, and the config is converted by
// convertIdentityProviderConfigToStore.
func convertIdentityProviderToStore(identityProvider *apiv2pb.IdentityProvider) *storepb.IdentityProvider {
	apiType := identityProvider.Type

	// The parse error is deliberately discarded. An unparsable or absent name
	// (presumably the create case) leaves whatever ID the parser returns with
	// its error, which is expected to be 0.
	providerID, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)

	return &storepb.IdentityProvider{
		Id:               providerID,
		Name:             identityProvider.Title,
		IdentifierFilter: identityProvider.IdentifierFilter,
		Type:             storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[apiType.String()]),
		Config:           convertIdentityProviderConfigToStore(apiType, identityProvider.Config),
	}
}
// convertIdentityProviderConfigToStore converts an API provider config into
// its store form.
//
// For OAUTH2 providers it always returns a store OAuth2 config. Fields missing
// from the request (a nil config, OAuth2 config or field mapping) become empty
// values instead of causing a nil-pointer panic. For every other provider type
// it returns nil.
func convertIdentityProviderConfigToStore(identityProviderType apiv2pb.IdentityProvider_Type, config *apiv2pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
	if identityProviderType != apiv2pb.IdentityProvider_OAUTH2 {
		return nil
	}
	// Generated getters are nil-safe, so absent nested messages read as zero values.
	oauth2Config := config.GetOauth2Config()
	fieldMapping := oauth2Config.GetFieldMapping()
	return &storepb.IdentityProviderConfig{
		Config: &storepb.IdentityProviderConfig_Oauth2Config{
			Oauth2Config: &storepb.OAuth2Config{
				ClientId:     oauth2Config.GetClientId(),
				ClientSecret: oauth2Config.GetClientSecret(),
				AuthUrl:      oauth2Config.GetAuthUrl(),
				TokenUrl:     oauth2Config.GetTokenUrl(),
				UserInfoUrl:  oauth2Config.GetUserInfoUrl(),
				Scopes:       oauth2Config.GetScopes(),
				FieldMapping: &storepb.FieldMapping{
					Identifier:  fieldMapping.GetIdentifier(),
					DisplayName: fieldMapping.GetDisplayName(),
					Email:       fieldMapping.GetEmail(),
				},
			},
		},
	}
}