mirror of
https://github.com/usememos/memos.git
synced 2025-02-01 01:58:14 +08:00
feat: introduce idp table (#1105)
* feat: introduce idp table * chore: update
This commit is contained in:
parent
c6e525b06f
commit
a538b9789b
2 changed files with 274 additions and 0 deletions
|
@ -114,4 +114,13 @@ CREATE TABLE storage (
|
|||
secret_key TEXT NOT NULL DEFAULT '',
|
||||
bucket TEXT NOT NULL DEFAULT '',
|
||||
url_prefix TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
-- idp
|
||||
CREATE TABLE idp (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
type TEXT NOT NULL,
|
||||
identifier_filter TEXT NOT NULL DEFAULT '',
|
||||
config TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
265
store/idp.go
Normal file
265
store/idp.go
Normal file
|
@ -0,0 +1,265 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/common"
|
||||
)
|
||||
|
||||
type IdentityProvideType string
|
||||
|
||||
const (
|
||||
IdentityProviderOAuth2 IdentityProvideType = "OAUTH2"
|
||||
)
|
||||
|
||||
type IdentityProviderConfig interface{}
|
||||
|
||||
type IdentityProviderOAuth2Config struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
AuthURL string `json:"authUrl"`
|
||||
TokenURL string `json:"tokenUrl"`
|
||||
UserInfoURL string `json:"userInfoUrl"`
|
||||
Scopes []string `json:"scopes"`
|
||||
FieldMapping *FieldMapping `json:"fieldMapping"`
|
||||
}
|
||||
|
||||
type FieldMapping struct {
|
||||
Identifier string
|
||||
DisplayName string
|
||||
Email string
|
||||
}
|
||||
|
||||
type IdentityProviderMessage struct {
|
||||
ID int
|
||||
Name string
|
||||
Type IdentityProvideType
|
||||
IdentifierFilter string
|
||||
Config *IdentityProviderConfig
|
||||
}
|
||||
|
||||
type FindIdentityProviderMessage struct {
|
||||
ID *int
|
||||
}
|
||||
|
||||
type UpdateIdentityProviderMessage struct {
|
||||
ID int
|
||||
Type IdentityProvideType
|
||||
Name *string
|
||||
IdentifierFilter *string
|
||||
Config *IdentityProviderConfig
|
||||
}
|
||||
|
||||
type DeleteIdentityProviderMessage struct {
|
||||
ID int
|
||||
}
|
||||
|
||||
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var configBytes []byte
|
||||
if create.Type == IdentityProviderOAuth2 {
|
||||
configBytes, err = json.Marshal(any(create.Config).(*IdentityProviderOAuth2Config))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
|
||||
}
|
||||
query := `
|
||||
INSERT INTO idp (
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id
|
||||
`
|
||||
if err := tx.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
create.Name,
|
||||
create.Type,
|
||||
create.IdentifierFilter,
|
||||
string(configBytes),
|
||||
).Scan(
|
||||
&create.ID,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
list, err := listIdentityProviders(ctx, tx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
list, err := listIdentityProviders(ctx, tx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
set, args := []string{}, []interface{}{}
|
||||
if v := update.Name; v != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
set, args = append(set, "identifier_filter = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
var configBytes []byte
|
||||
if update.Type == IdentityProviderOAuth2 {
|
||||
configBytes, err = json.Marshal(any(update.Config).(*IdentityProviderOAuth2Config))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
|
||||
}
|
||||
set, args = append(set, "config = ?"), append(args, string(configBytes))
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
query := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
var identityProviderMessage IdentityProviderMessage
|
||||
var identityProviderConfig string
|
||||
if err := tx.QueryRowContext(ctx, query, args...).Scan(
|
||||
&identityProviderMessage.ID,
|
||||
&identityProviderMessage.Name,
|
||||
&identityProviderMessage.Type,
|
||||
&identityProviderMessage.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
if identityProviderMessage.Type == IdentityProviderOAuth2 {
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
|
||||
}
|
||||
|
||||
return &identityProviderMessage, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
where, args := []string{"id = ?"}, []interface{}{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := tx.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return FormatError(err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
|
||||
where, args := []string{"TRUE"}, []interface{}{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var identityProviderMessages []*IdentityProviderMessage
|
||||
for rows.Next() {
|
||||
var identityProviderMessage IdentityProviderMessage
|
||||
var identityProviderConfig string
|
||||
if err := rows.Scan(
|
||||
&identityProviderMessage.ID,
|
||||
&identityProviderMessage.Name,
|
||||
&identityProviderMessage.Type,
|
||||
&identityProviderMessage.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
if identityProviderMessage.Type == IdentityProviderOAuth2 {
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
|
||||
}
|
||||
identityProviderMessages = append(identityProviderMessages, &identityProviderMessage)
|
||||
}
|
||||
|
||||
return identityProviderMessages, nil
|
||||
}
|
Loading…
Reference in a new issue