mirror of
				https://github.com/usememos/memos.git
				synced 2025-11-01 01:06:04 +08:00 
			
		
		
		
	refactor: migrate idp to driver
This commit is contained in:
		
							parent
							
								
									63b55c4f65
								
							
						
					
					
						commit
						d68da34eec
					
				
					 3 changed files with 204 additions and 142 deletions
				
			
		|  | @ -26,4 +26,10 @@ type Driver interface { | |||
| 	ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) | ||||
| 	UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) | ||||
| 	ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) | ||||
| 
 | ||||
| 	CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) | ||||
| 	ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) | ||||
| 	GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) | ||||
| 	UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) | ||||
| 	DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error | ||||
| } | ||||
|  |  | |||
							
								
								
									
										150
									
								
								store/idp.go
									
										
									
									
									
								
							
							
						
						
									
										150
									
								
								store/idp.go
									
										
									
									
									
								
							|  | @ -2,11 +2,6 @@ package store | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
| 
 | ||||
| type IdentityProviderType string | ||||
|  | @ -64,98 +59,20 @@ type DeleteIdentityProvider struct { | |||
| } | ||||
| 
 | ||||
| func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { | ||||
| 	var configBytes []byte | ||||
| 	if create.Type == IdentityProviderOAuth2Type { | ||||
| 		bytes, err := json.Marshal(create.Config.OAuth2Config) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		configBytes = bytes | ||||
| 	} else { | ||||
| 		return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt := ` | ||||
| 		INSERT INTO idp ( | ||||
| 			name, | ||||
| 			type, | ||||
| 			identifier_filter, | ||||
| 			config | ||||
| 		) | ||||
| 		VALUES (?, ?, ?, ?) | ||||
| 		RETURNING id | ||||
| 	` | ||||
| 	if err := s.db.QueryRowContext( | ||||
| 		ctx, | ||||
| 		stmt, | ||||
| 		create.Name, | ||||
| 		create.Type, | ||||
| 		create.IdentifierFilter, | ||||
| 		string(configBytes), | ||||
| 	).Scan( | ||||
| 		&create.ID, | ||||
| 	); err != nil { | ||||
| 	identityProvider, err := s.driver.CreateIdentityProvider(ctx, create) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	identityProvider := create | ||||
| 	s.idpCache.Store(identityProvider.ID, identityProvider) | ||||
| 	return identityProvider, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { | ||||
| 	where, args := []string{"1 = 1"}, []any{} | ||||
| 	if v := find.ID; v != nil { | ||||
| 		where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := s.db.QueryContext(ctx, ` | ||||
| 		SELECT | ||||
| 			id, | ||||
| 			name, | ||||
| 			type, | ||||
| 			identifier_filter, | ||||
| 			config | ||||
| 		FROM idp | ||||
| 		WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, | ||||
| 		args..., | ||||
| 	) | ||||
| 	identityProviders, err := s.driver.ListIdentityProviders(ctx, find) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	var identityProviders []*IdentityProvider | ||||
| 	for rows.Next() { | ||||
| 		var identityProvider IdentityProvider | ||||
| 		var identityProviderConfig string | ||||
| 		if err := rows.Scan( | ||||
| 			&identityProvider.ID, | ||||
| 			&identityProvider.Name, | ||||
| 			&identityProvider.Type, | ||||
| 			&identityProvider.IdentifierFilter, | ||||
| 			&identityProviderConfig, | ||||
| 		); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		if identityProvider.Type == IdentityProviderOAuth2Type { | ||||
| 			oauth2Config := &IdentityProviderOAuth2Config{} | ||||
| 			if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			identityProvider.Config = &IdentityProviderConfig{ | ||||
| 				OAuth2Config: oauth2Config, | ||||
| 			} | ||||
| 		} else { | ||||
| 			return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) | ||||
| 		} | ||||
| 		identityProviders = append(identityProviders, &identityProvider) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := rows.Err(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, item := range identityProviders { | ||||
| 		s.idpCache.Store(item.ID, item) | ||||
|  | @ -184,72 +101,21 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi | |||
| } | ||||
| 
 | ||||
| func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { | ||||
| 	set, args := []string{}, []any{} | ||||
| 	if v := update.Name; v != nil { | ||||
| 		set, args = append(set, "name = ?"), append(args, *v) | ||||
| 	} | ||||
| 	if v := update.IdentifierFilter; v != nil { | ||||
| 		set, args = append(set, "identifier_filter = ?"), append(args, *v) | ||||
| 	} | ||||
| 	if v := update.Config; v != nil { | ||||
| 		var configBytes []byte | ||||
| 		if update.Type == IdentityProviderOAuth2Type { | ||||
| 			bytes, err := json.Marshal(update.Config.OAuth2Config) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			configBytes = bytes | ||||
| 		} else { | ||||
| 			return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) | ||||
| 		} | ||||
| 		set, args = append(set, "config = ?"), append(args, string(configBytes)) | ||||
| 	} | ||||
| 	args = append(args, update.ID) | ||||
| 
 | ||||
| 	stmt := ` | ||||
| 		UPDATE idp | ||||
| 		SET ` + strings.Join(set, ", ") + ` | ||||
| 		WHERE id = ? | ||||
| 		RETURNING id, name, type, identifier_filter, config | ||||
| 	` | ||||
| 	var identityProvider IdentityProvider | ||||
| 	var identityProviderConfig string | ||||
| 	if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( | ||||
| 		&identityProvider.ID, | ||||
| 		&identityProvider.Name, | ||||
| 		&identityProvider.Type, | ||||
| 		&identityProvider.IdentifierFilter, | ||||
| 		&identityProviderConfig, | ||||
| 	); err != nil { | ||||
| 	identityProvider, err := s.driver.UpdateIdentityProvider(ctx, update) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if identityProvider.Type == IdentityProviderOAuth2Type { | ||||
| 		oauth2Config := &IdentityProviderOAuth2Config{} | ||||
| 		if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		identityProvider.Config = &IdentityProviderConfig{ | ||||
| 			OAuth2Config: oauth2Config, | ||||
| 		} | ||||
| 	} else { | ||||
| 		return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) | ||||
| 	} | ||||
| 
 | ||||
| 	s.idpCache.Store(identityProvider.ID, identityProvider) | ||||
| 	return &identityProvider, nil | ||||
| 	return identityProvider, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { | ||||
| 	where, args := []string{"id = ?"}, []any{delete.ID} | ||||
| 	stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") | ||||
| 	result, err := s.db.ExecContext(ctx, stmt, args...) | ||||
| 	err := s.driver.DeleteIdentityProvider(ctx, delete) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if _, err = result.RowsAffected(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	s.idpCache.Delete(delete.ID) | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
							
								
								
									
										190
									
								
								store/sqlite/idp.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								store/sqlite/idp.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,190 @@ | |||
| package sqlite | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 
 | ||||
| 	"github.com/usememos/memos/store" | ||||
| ) | ||||
| 
 | ||||
| func (d *Driver) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { | ||||
| 	var configBytes []byte | ||||
| 	if create.Type == store.IdentityProviderOAuth2Type { | ||||
| 		bytes, err := json.Marshal(create.Config.OAuth2Config) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		configBytes = bytes | ||||
| 	} else { | ||||
| 		return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) | ||||
| 	} | ||||
| 
 | ||||
| 	stmt := ` | ||||
| 		INSERT INTO idp ( | ||||
| 			name, | ||||
| 			type, | ||||
| 			identifier_filter, | ||||
| 			config | ||||
| 		) | ||||
| 		VALUES (?, ?, ?, ?) | ||||
| 		RETURNING id | ||||
| 	` | ||||
| 	if err := d.db.QueryRowContext( | ||||
| 		ctx, | ||||
| 		stmt, | ||||
| 		create.Name, | ||||
| 		create.Type, | ||||
| 		create.IdentifierFilter, | ||||
| 		string(configBytes), | ||||
| 	).Scan( | ||||
| 		&create.ID, | ||||
| 	); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	identityProvider := create | ||||
| 	return identityProvider, nil | ||||
| } | ||||
| 
 | ||||
| func (d *Driver) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { | ||||
| 	where, args := []string{"1 = 1"}, []any{} | ||||
| 	if v := find.ID; v != nil { | ||||
| 		where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := d.db.QueryContext(ctx, ` | ||||
| 		SELECT | ||||
| 			id, | ||||
| 			name, | ||||
| 			type, | ||||
| 			identifier_filter, | ||||
| 			config | ||||
| 		FROM idp | ||||
| 		WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, | ||||
| 		args..., | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	var identityProviders []*store.IdentityProvider | ||||
| 	for rows.Next() { | ||||
| 		var identityProvider store.IdentityProvider | ||||
| 		var identityProviderConfig string | ||||
| 		if err := rows.Scan( | ||||
| 			&identityProvider.ID, | ||||
| 			&identityProvider.Name, | ||||
| 			&identityProvider.Type, | ||||
| 			&identityProvider.IdentifierFilter, | ||||
| 			&identityProviderConfig, | ||||
| 		); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		if identityProvider.Type == store.IdentityProviderOAuth2Type { | ||||
| 			oauth2Config := &store.IdentityProviderOAuth2Config{} | ||||
| 			if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			identityProvider.Config = &store.IdentityProviderConfig{ | ||||
| 				OAuth2Config: oauth2Config, | ||||
| 			} | ||||
| 		} else { | ||||
| 			return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) | ||||
| 		} | ||||
| 		identityProviders = append(identityProviders, &identityProvider) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := rows.Err(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return identityProviders, nil | ||||
| } | ||||
| 
 | ||||
| func (d *Driver) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) { | ||||
| 	list, err := d.ListIdentityProviders(ctx, find) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(list) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	identityProvider := list[0] | ||||
| 	return identityProvider, nil | ||||
| } | ||||
| 
 | ||||
| func (d *Driver) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { | ||||
| 	set, args := []string{}, []any{} | ||||
| 	if v := update.Name; v != nil { | ||||
| 		set, args = append(set, "name = ?"), append(args, *v) | ||||
| 	} | ||||
| 	if v := update.IdentifierFilter; v != nil { | ||||
| 		set, args = append(set, "identifier_filter = ?"), append(args, *v) | ||||
| 	} | ||||
| 	if v := update.Config; v != nil { | ||||
| 		var configBytes []byte | ||||
| 		if update.Type == store.IdentityProviderOAuth2Type { | ||||
| 			bytes, err := json.Marshal(update.Config.OAuth2Config) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			configBytes = bytes | ||||
| 		} else { | ||||
| 			return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) | ||||
| 		} | ||||
| 		set, args = append(set, "config = ?"), append(args, string(configBytes)) | ||||
| 	} | ||||
| 	args = append(args, update.ID) | ||||
| 
 | ||||
| 	stmt := ` | ||||
| 		UPDATE idp | ||||
| 		SET ` + strings.Join(set, ", ") + ` | ||||
| 		WHERE id = ? | ||||
| 		RETURNING id, name, type, identifier_filter, config | ||||
| 	` | ||||
| 	var identityProvider store.IdentityProvider | ||||
| 	var identityProviderConfig string | ||||
| 	if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( | ||||
| 		&identityProvider.ID, | ||||
| 		&identityProvider.Name, | ||||
| 		&identityProvider.Type, | ||||
| 		&identityProvider.IdentifierFilter, | ||||
| 		&identityProviderConfig, | ||||
| 	); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if identityProvider.Type == store.IdentityProviderOAuth2Type { | ||||
| 		oauth2Config := &store.IdentityProviderOAuth2Config{} | ||||
| 		if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		identityProvider.Config = &store.IdentityProviderConfig{ | ||||
| 			OAuth2Config: oauth2Config, | ||||
| 		} | ||||
| 	} else { | ||||
| 		return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) | ||||
| 	} | ||||
| 
 | ||||
| 	return &identityProvider, nil | ||||
| } | ||||
| 
 | ||||
| func (d *Driver) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { | ||||
| 	where, args := []string{"id = ?"}, []any{delete.ID} | ||||
| 	stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") | ||||
| 	result, err := d.db.ExecContext(ctx, stmt, args...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if _, err = result.RowsAffected(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
		Loading…
	
	Add table
		
		Reference in a new issue