2023-02-17 21:06:41 +08:00
|
|
|
package store
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
"github.com/pkg/errors"
|
|
|
|
"google.golang.org/protobuf/encoding/protojson"
|
2023-02-17 21:06:41 +08:00
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
2023-02-17 21:06:41 +08:00
|
|
|
)
|
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
type IdentityProvider struct {
|
2023-08-04 21:55:07 +08:00
|
|
|
ID int32
|
2023-02-17 21:06:41 +08:00
|
|
|
Name string
|
2024-04-13 10:50:25 +08:00
|
|
|
Type storepb.IdentityProvider_Type
|
2023-02-17 21:06:41 +08:00
|
|
|
IdentifierFilter string
|
2024-04-13 10:50:25 +08:00
|
|
|
Config string
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
type FindIdentityProvider struct {
|
2023-08-04 21:55:07 +08:00
|
|
|
ID *int32
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
type UpdateIdentityProvider struct {
|
2023-08-04 21:55:07 +08:00
|
|
|
ID int32
|
2023-02-17 21:06:41 +08:00
|
|
|
Name *string
|
|
|
|
IdentifierFilter *string
|
2024-04-13 10:50:25 +08:00
|
|
|
Config *string
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
type DeleteIdentityProvider struct {
|
2023-08-04 21:55:07 +08:00
|
|
|
ID int32
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
func (s *Store) CreateIdentityProviderV1(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
|
|
|
|
raw, err := convertIdentityProviderToRaw(create)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
|
2023-09-26 19:17:17 +08:00
|
|
|
if err != nil {
|
2023-06-26 23:46:01 +08:00
|
|
|
return nil, err
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
2023-06-26 23:46:01 +08:00
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
|
2023-06-26 23:46:01 +08:00
|
|
|
return identityProvider, nil
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
func (s *Store) ListIdentityProvidersV1(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
|
|
|
|
list, err := s.driver.ListIdentityProviders(ctx, find)
|
2023-02-17 21:06:41 +08:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-07-06 21:56:42 +08:00
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
identityProviders := []*storepb.IdentityProvider{}
|
|
|
|
for _, raw := range list {
|
|
|
|
identityProvider, err := convertIdentityProviderFromRaw(raw)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
|
2023-02-18 18:41:52 +08:00
|
|
|
}
|
2023-07-20 23:15:56 +08:00
|
|
|
return identityProviders, nil
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
func (s *Store) GetIdentityProviderV1(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
|
2023-02-18 18:41:52 +08:00
|
|
|
if find.ID != nil {
|
2024-04-13 10:50:25 +08:00
|
|
|
if cache, ok := s.idpV1Cache.Load(*find.ID); ok {
|
|
|
|
return cache.(*storepb.IdentityProvider), nil
|
2023-02-18 18:41:52 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
list, err := s.ListIdentityProvidersV1(ctx, find)
|
2023-02-17 21:06:41 +08:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if len(list) == 0 {
|
2023-06-26 23:46:01 +08:00
|
|
|
return nil, nil
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
2024-04-13 10:50:25 +08:00
|
|
|
if len(list) > 1 {
|
|
|
|
return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
|
|
|
|
}
|
2023-02-17 21:06:41 +08:00
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
identityProvider := list[0]
|
|
|
|
return identityProvider, nil
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
type UpdateIdentityProviderV1 struct {
|
|
|
|
ID int32
|
|
|
|
Type storepb.IdentityProvider_Type
|
|
|
|
Name *string
|
|
|
|
IdentifierFilter *string
|
|
|
|
Config *storepb.IdentityProviderConfig
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *Store) UpdateIdentityProviderV1(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
|
|
|
|
updateRaw := &UpdateIdentityProvider{
|
|
|
|
ID: update.ID,
|
|
|
|
}
|
|
|
|
if update.Name != nil {
|
|
|
|
updateRaw.Name = update.Name
|
|
|
|
}
|
|
|
|
if update.IdentifierFilter != nil {
|
|
|
|
updateRaw.IdentifierFilter = update.IdentifierFilter
|
|
|
|
}
|
|
|
|
if update.Config != nil {
|
|
|
|
configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
updateRaw.Config = &configRaw
|
|
|
|
}
|
|
|
|
identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
|
2023-09-26 19:17:17 +08:00
|
|
|
if err != nil {
|
2023-06-26 23:46:01 +08:00
|
|
|
return nil, err
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
2023-06-26 23:46:01 +08:00
|
|
|
|
2024-04-13 10:50:25 +08:00
|
|
|
identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
|
2023-09-26 19:17:17 +08:00
|
|
|
return identityProvider, nil
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
|
|
|
|
2023-06-26 23:46:01 +08:00
|
|
|
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
|
2023-09-26 19:17:17 +08:00
|
|
|
err := s.driver.DeleteIdentityProvider(ctx, delete)
|
2023-02-17 21:06:41 +08:00
|
|
|
if err != nil {
|
2023-06-26 23:46:01 +08:00
|
|
|
return err
|
2023-02-17 21:06:41 +08:00
|
|
|
}
|
2023-09-26 19:17:17 +08:00
|
|
|
|
2023-02-18 18:41:52 +08:00
|
|
|
s.idpCache.Delete(delete.ID)
|
2023-02-17 21:06:41 +08:00
|
|
|
return nil
|
|
|
|
}
|
2024-04-13 10:50:25 +08:00
|
|
|
|
|
|
|
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
|
|
|
|
identityProvider := &storepb.IdentityProvider{
|
|
|
|
Id: raw.ID,
|
|
|
|
Name: raw.Name,
|
|
|
|
Type: raw.Type,
|
|
|
|
IdentifierFilter: raw.IdentifierFilter,
|
|
|
|
}
|
|
|
|
config, err := convertIdentityProviderConfigFromRaw(identityProvider.Type, raw.Config)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
identityProvider.Config = config
|
|
|
|
return identityProvider, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
|
|
|
|
raw := &IdentityProvider{
|
|
|
|
ID: identityProvider.Id,
|
|
|
|
Name: identityProvider.Name,
|
|
|
|
Type: identityProvider.Type,
|
|
|
|
IdentifierFilter: identityProvider.IdentifierFilter,
|
|
|
|
}
|
|
|
|
configRaw, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
raw.Config = configRaw
|
|
|
|
return raw, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
|
|
|
|
config := &storepb.IdentityProviderConfig{}
|
|
|
|
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
|
|
|
|
oauth2Config := &storepb.OAuth2Config{}
|
|
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2Config); err != nil {
|
|
|
|
return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
|
|
|
|
}
|
|
|
|
config.Config = &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2Config}
|
|
|
|
}
|
|
|
|
return config, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
|
|
|
|
raw := ""
|
|
|
|
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
|
|
|
|
bytes, err := protojson.Marshal(config.GetOauth2Config())
|
|
|
|
if err != nil {
|
|
|
|
return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
|
|
|
|
}
|
|
|
|
raw = string(bytes)
|
|
|
|
}
|
|
|
|
return raw, nil
|
|
|
|
}
|