memos/store/idp.go

193 lines
5.3 KiB
Go
Raw Normal View History

package store
import (
"context"
2024-04-13 10:50:25 +08:00
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
2024-04-13 10:50:25 +08:00
storepb "github.com/usememos/memos/proto/gen/store"
)
2023-06-26 23:46:01 +08:00
type IdentityProvider struct {
2023-08-04 21:55:07 +08:00
ID int32
Name string
2024-04-13 10:50:25 +08:00
Type storepb.IdentityProvider_Type
IdentifierFilter string
2024-04-13 10:50:25 +08:00
Config string
}
2023-06-26 23:46:01 +08:00
// FindIdentityProvider is the lookup filter handed to the driver when listing
// identity providers. When ID is non-nil, GetIdentityProviderV1 also uses it
// as the V1 cache key. Presumably a nil ID means "no ID filter" at the driver
// level — confirm against the driver implementations.
type FindIdentityProvider struct {
	// ID optionally restricts the lookup to a single provider.
	ID *int32
}
2023-06-26 23:46:01 +08:00
// UpdateIdentityProvider is the raw partial update passed to the driver.
// ID selects the provider to modify. Each pointer field carries a new value
// only when non-nil; the driver presumably leaves nil fields unchanged.
type UpdateIdentityProvider struct {
	ID int32
	// Name, if non-nil, is the new display name.
	Name *string
	// IdentifierFilter, if non-nil, replaces the stored identifier filter.
	IdentifierFilter *string
	// Config, if non-nil, is the new configuration in serialized form.
	Config *string
}
2023-06-26 23:46:01 +08:00
// DeleteIdentityProvider identifies the identity provider to remove.
type DeleteIdentityProvider struct {
	// ID is the identifier of the provider to delete.
	ID int32
}
2024-04-13 10:50:25 +08:00
func (s *Store) CreateIdentityProviderV1(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
raw, err := convertIdentityProviderToRaw(create)
if err != nil {
return nil, err
}
identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
2023-09-26 19:17:17 +08:00
if err != nil {
2023-06-26 23:46:01 +08:00
return nil, err
}
2023-06-26 23:46:01 +08:00
2024-04-13 10:50:25 +08:00
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
2023-06-26 23:46:01 +08:00
return identityProvider, nil
}
2024-04-13 10:50:25 +08:00
func (s *Store) ListIdentityProvidersV1(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
list, err := s.driver.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
}
2024-04-13 10:50:25 +08:00
identityProviders := []*storepb.IdentityProvider{}
for _, raw := range list {
identityProvider, err := convertIdentityProviderFromRaw(raw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
}
return identityProviders, nil
}
2024-04-13 10:50:25 +08:00
func (s *Store) GetIdentityProviderV1(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
if find.ID != nil {
2024-04-13 10:50:25 +08:00
if cache, ok := s.idpV1Cache.Load(*find.ID); ok {
return cache.(*storepb.IdentityProvider), nil
}
}
2024-04-13 10:50:25 +08:00
list, err := s.ListIdentityProvidersV1(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
2023-06-26 23:46:01 +08:00
return nil, nil
}
2024-04-13 10:50:25 +08:00
if len(list) > 1 {
return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
}
2023-06-26 23:46:01 +08:00
identityProvider := list[0]
return identityProvider, nil
}
2024-04-13 10:50:25 +08:00
// UpdateIdentityProviderV1 describes a partial update of an identity provider
// expressed with store protobuf types. Nil pointer fields are not included in
// the update.
type UpdateIdentityProviderV1 struct {
	// ID selects the provider to update.
	ID int32
	// Type chooses how Config is serialized. It is not written to storage by
	// UpdateIdentityProviderV1.
	Type             storepb.IdentityProvider_Type
	Name             *string
	IdentifierFilter *string
	Config           *storepb.IdentityProviderConfig
}
func (s *Store) UpdateIdentityProviderV1(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
updateRaw := &UpdateIdentityProvider{
ID: update.ID,
}
if update.Name != nil {
updateRaw.Name = update.Name
}
if update.IdentifierFilter != nil {
updateRaw.IdentifierFilter = update.IdentifierFilter
}
if update.Config != nil {
configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
if err != nil {
return nil, err
}
updateRaw.Config = &configRaw
}
identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
2023-09-26 19:17:17 +08:00
if err != nil {
2023-06-26 23:46:01 +08:00
return nil, err
}
2023-06-26 23:46:01 +08:00
2024-04-13 10:50:25 +08:00
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
2023-09-26 19:17:17 +08:00
return identityProvider, nil
}
2023-06-26 23:46:01 +08:00
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
2023-09-26 19:17:17 +08:00
err := s.driver.DeleteIdentityProvider(ctx, delete)
if err != nil {
2023-06-26 23:46:01 +08:00
return err
}
2023-09-26 19:17:17 +08:00
s.idpCache.Delete(delete.ID)
return nil
}
2024-04-13 10:50:25 +08:00
// convertIdentityProviderFromRaw builds the protobuf form of a raw identity
// provider, decoding raw.Config according to raw.Type.
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
	decodedConfig, err := convertIdentityProviderConfigFromRaw(raw.Type, raw.Config)
	if err != nil {
		return nil, err
	}
	return &storepb.IdentityProvider{
		Id:               raw.ID,
		Name:             raw.Name,
		Type:             raw.Type,
		IdentifierFilter: raw.IdentifierFilter,
		Config:           decodedConfig,
	}, nil
}
// convertIdentityProviderToRaw builds the raw, driver-facing form of an
// identity provider, serializing its Config according to its Type.
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
	encodedConfig, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
	if err != nil {
		return nil, err
	}
	return &IdentityProvider{
		ID:               identityProvider.Id,
		Name:             identityProvider.Name,
		Type:             identityProvider.Type,
		IdentifierFilter: identityProvider.IdentifierFilter,
		Config:           encodedConfig,
	}, nil
}
// convertIdentityProviderConfigFromRaw decodes a serialized provider config.
// For OAUTH2 providers, raw is parsed as protojson into an OAuth2Config using
// protojsonUnmarshaler, which is declared elsewhere in the package. For any
// other type, raw is ignored and an empty config is returned.
func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
	if identityProviderType != storepb.IdentityProvider_OAUTH2 {
		return &storepb.IdentityProviderConfig{}, nil
	}

	oauth2 := &storepb.OAuth2Config{}
	if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2); err != nil {
		return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
	}
	return &storepb.IdentityProviderConfig{
		Config: &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2},
	}, nil
}
// convertIdentityProviderConfigToRaw serializes a provider config for storage.
// For OAUTH2 providers, the OAuth2 section of config is encoded with protojson.
// Every other type has no serialized form and yields an empty string.
func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
	if identityProviderType != storepb.IdentityProvider_OAUTH2 {
		return "", nil
	}

	encoded, err := protojson.Marshal(config.GetOauth2Config())
	if err != nil {
		return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
	}
	return string(encoded), nil
}