diff --git a/store/db/migration/dev/LATEST__SCHEMA.sql b/store/db/migration/dev/LATEST__SCHEMA.sql index 8f3db83f..268018d7 100644 --- a/store/db/migration/dev/LATEST__SCHEMA.sql +++ b/store/db/migration/dev/LATEST__SCHEMA.sql @@ -114,4 +114,13 @@ CREATE TABLE storage ( secret_key TEXT NOT NULL DEFAULT '', bucket TEXT NOT NULL DEFAULT '', url_prefix TEXT NOT NULL DEFAULT '' +); + +-- idp +CREATE TABLE idp ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL DEFAULT '', + type TEXT NOT NULL, + identifier_filter TEXT NOT NULL DEFAULT '', + config TEXT NOT NULL DEFAULT '{}' ); \ No newline at end of file diff --git a/store/idp.go b/store/idp.go new file mode 100644 index 00000000..f64e38d5 --- /dev/null +++ b/store/idp.go @@ -0,0 +1,265 @@ +package store + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/usememos/memos/common" +) + +type IdentityProvideType string + +const ( + IdentityProviderOAuth2 IdentityProvideType = "OAUTH2" +) + +type IdentityProviderConfig interface{} + +type IdentityProviderOAuth2Config struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AuthURL string `json:"authUrl"` + TokenURL string `json:"tokenUrl"` + UserInfoURL string `json:"userInfoUrl"` + Scopes []string `json:"scopes"` + FieldMapping *FieldMapping `json:"fieldMapping"` +} + +type FieldMapping struct { + Identifier string + DisplayName string + Email string +} + +type IdentityProviderMessage struct { + ID int + Name string + Type IdentityProvideType + IdentifierFilter string + Config *IdentityProviderConfig +} + +type FindIdentityProviderMessage struct { + ID *int +} + +type UpdateIdentityProviderMessage struct { + ID int + Type IdentityProvideType + Name *string + IdentifierFilter *string + Config *IdentityProviderConfig +} + +type DeleteIdentityProviderMessage struct { + ID int +} + +func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + var configBytes []byte + if create.Type == IdentityProviderOAuth2 { + configBytes, err = json.Marshal(any(create.Config).(*IdentityProviderOAuth2Config)) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(create.Type)) + } + query := ` + INSERT INTO idp ( + name, + type, + identifier_filter, + config + ) + VALUES (?, ?, ?, ?) + RETURNING id + ` + if err := tx.QueryRowContext( + ctx, + query, + create.Name, + create.Type, + create.IdentifierFilter, + string(configBytes), + ).Scan( + &create.ID, + ); err != nil { + return nil, FormatError(err) + } + + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + return create, nil +} + +func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := listIdentityProviders(ctx, tx, find) + if err != nil { + return nil, err + } + + return list, nil +} + +func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := listIdentityProviders(ctx, tx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + } + + return list[0], nil +} + +func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + set, args := []string{}, []interface{}{} + if v := update.Name; v != nil { + set, args = append(set, "name = ?"), append(args, *v) + } + if v := update.IdentifierFilter; v != nil { + set, args = append(set, "identifier_filter = ?"), append(args, *v) + } + if v := update.Config; v != nil { + var configBytes []byte + if update.Type == IdentityProviderOAuth2 { + configBytes, err = json.Marshal(any(update.Config).(*IdentityProviderOAuth2Config)) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(update.Type)) + } + set, args = append(set, "config = ?"), append(args, string(configBytes)) + } + args = append(args, update.ID) + + query := ` + UPDATE idp + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, name, type, identifier_filter, config + ` + var identityProviderMessage IdentityProviderMessage + var identityProviderConfig string + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &identityProviderMessage.ID, + &identityProviderMessage.Name, + &identityProviderMessage.Type, + &identityProviderMessage.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, FormatError(err) + } + if identityProviderMessage.Type == IdentityProviderOAuth2 { + if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) + } + + return &identityProviderMessage, nil +} + +func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + where, args := []string{"id = ?"}, []interface{}{delete.ID} + stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") + result, err := tx.ExecContext(ctx, stmt, args...) + if err != nil { + return FormatError(err) + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")} + } + return nil +} + +func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) { + where, args := []string{"TRUE"}, []interface{}{} + if v := find.ID; v != nil { + where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) + } + + rows, err := tx.QueryContext(ctx, ` + SELECT + id, + name, + type, + identifier_filter, + config + FROM idp + WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, + args..., + ) + if err != nil { + return nil, FormatError(err) + } + defer rows.Close() + + var identityProviderMessages []*IdentityProviderMessage + for rows.Next() { + var identityProviderMessage IdentityProviderMessage + var identityProviderConfig string + if err := rows.Scan( + &identityProviderMessage.ID, + &identityProviderMessage.Name, + &identityProviderMessage.Type, + &identityProviderMessage.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, FormatError(err) + } + if identityProviderMessage.Type == IdentityProviderOAuth2 { + if err := json.Unmarshal([]byte(identityProviderConfig), any(identityProviderMessage.Config).(*IdentityProviderOAuth2Config)); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) + } + identityProviderMessages = append(identityProviderMessages, &identityProviderMessage) + } + + return identityProviderMessages, nil +}