diff --git a/server/profile/profile.go b/server/profile/profile.go index f4e9ab66f..3cce53ea9 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -26,6 +26,7 @@ type Profile struct { // DSN points to where Memos stores its own data DSN string `json:"dsn"` // Driver is the database driver + // sqlite, mysql Driver string `json:"driver"` // Version is the current version of server Version string `json:"version"` diff --git a/store/driver.go b/store/driver.go index 12778a469..299b82620 100644 --- a/store/driver.go +++ b/store/driver.go @@ -2,11 +2,13 @@ package store import ( "context" + "database/sql" storepb "github.com/usememos/memos/proto/gen/store" ) type Driver interface { + GetDB() *sql.DB Close() error Migrate(ctx context.Context) error diff --git a/store/mysql/idp.go b/store/mysql/idp.go index fc78f8ad7..1b8506daf 100644 --- a/store/mysql/idp.go +++ b/store/mysql/idp.go @@ -3,7 +3,6 @@ package mysql import ( "context" "encoding/json" - "fmt" "strings" "github.com/pkg/errors" @@ -56,7 +55,7 @@ func (d *Driver) CreateIdentityProvider(ctx context.Context, create *store.Ident func (d *Driver) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) + where, args = append(where, "id = ?"), append(args, *v) } rows, err := d.db.QueryContext(ctx, ` @@ -150,39 +149,22 @@ func (d *Driver) UpdateIdentityProvider(ctx context.Context, update *store.Updat UPDATE idp SET ` + strings.Join(set, ", ") + ` WHERE id = ? - RETURNING id, name, type, identifier_filter, config ` _, err := d.db.ExecContext(ctx, stmt, args...) if err != nil { return nil, err } - var identityProvider store.IdentityProvider - var identityProviderConfig string - stmt = `SELECT id, name, type, identifier_filter, config FROM idp WHERE id = ?` - if err := d.db.QueryRowContext(ctx, stmt, update.ID).Scan( - &identityProvider.ID, - &identityProvider.Name, - &identityProvider.Type, - &identityProvider.IdentifierFilter, - &identityProviderConfig, - ); err != nil { + identityProvider, err := d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ + ID: &update.ID, + }) + if err != nil { return nil, err } - - if identityProvider.Type == store.IdentityProviderOAuth2Type { - oauth2Config := &store.IdentityProviderOAuth2Config{} - if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { - return nil, err - } - identityProvider.Config = &store.IdentityProviderConfig{ - OAuth2Config: oauth2Config, - } - } else { - return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + if identityProvider == nil { + return nil, errors.Errorf("idp %d not found", update.ID) } - - return &identityProvider, nil + return identityProvider, nil } func (d *Driver) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { diff --git a/store/mysql/memo.go b/store/mysql/memo.go index 6cd9b138d..92142e18b 100644 --- a/store/mysql/memo.go +++ b/store/mysql/memo.go @@ -32,37 +32,19 @@ func (d *Driver) CreateMemo(ctx context.Context, create *store.Memo) (*store.Mem return nil, err } - id, err := result.LastInsertId() + rawID, err := result.LastInsertId() if err != nil { return nil, err } - - var memo store.Memo - stmt = ` - SELECT - id, - creator_id, - content, - visibility, - UNIX_TIMESTAMP(created_ts), - UNIX_TIMESTAMP(updated_ts), - row_status - FROM memo - WHERE id = ? - ` - if err := d.db.QueryRowContext(ctx, stmt, id).Scan( - &memo.ID, - &memo.CreatorID, - &memo.Content, - &memo.Visibility, - &memo.UpdatedTs, - &memo.CreatedTs, - &memo.RowStatus, - ); err != nil { + id := int32(rawID) + memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id}) + if err != nil { return nil, err } - - return &memo, nil + if memo == nil { + return nil, errors.Errorf("failed to create memo") + } + return memo, nil } func (d *Driver) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { @@ -211,6 +193,19 @@ func (d *Driver) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store. return list, nil } +func (d *Driver) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) { + list, err := d.ListMemos(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memo := list[0] + return memo, nil +} + func (d *Driver) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error { set, args := []string{}, []any{} if v := update.CreatedTs; v != nil { diff --git a/store/mysql/migration/dev/LATEST__SCHEMA.sql b/store/mysql/migration/dev/LATEST__SCHEMA.sql index 13cbabd4d..90a6b691a 100644 --- a/store/mysql/migration/dev/LATEST__SCHEMA.sql +++ b/store/mysql/migration/dev/LATEST__SCHEMA.sql @@ -1,109 +1,34 @@ --- activity -CREATE TABLE IF NOT EXISTS `activity` ( - `id` int NOT NULL AUTO_INCREMENT, - `creator_id` int NOT NULL, - `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `type` varchar(255) NOT NULL DEFAULT '', - `level` varchar(255) NOT NULL DEFAULT 'INFO', - `payload` text NOT NULL, - PRIMARY KEY (`id`), - CONSTRAINT `activity_chk_1` CHECK ((`level` in (_utf8mb4'INFO',_utf8mb4'WARN',_utf8mb4'ERROR'))) -); - --- idp -CREATE TABLE IF NOT EXISTS `idp` ( - `id` int NOT NULL AUTO_INCREMENT, - `name` text NOT NULL, - `type` text NOT NULL, - `identifier_filter` varchar(256) NOT NULL DEFAULT '', - `config` text NOT NULL, - PRIMARY KEY (`id`) -); - --- memo -CREATE TABLE IF NOT EXISTS `memo` ( - `id` int NOT NULL AUTO_INCREMENT, - `creator_id` int NOT NULL, - `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', - `content` text NOT NULL, - `visibility` varchar(255) NOT NULL DEFAULT 'PRIVATE', - PRIMARY KEY (`id`), - KEY `creator_id` (`creator_id`), - KEY `visibility` (`visibility`), - CONSTRAINT `memo_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), - CONSTRAINT `memo_chk_2` CHECK ((`visibility` in (_utf8mb4'PUBLIC',_utf8mb4'PROTECTED',_utf8mb4'PRIVATE'))) -); - --- memo_organizer -CREATE TABLE IF NOT EXISTS `memo_organizer` ( - `memo_id` int NOT NULL, - `user_id` int NOT NULL, - `pinned` int NOT NULL DEFAULT '0', - UNIQUE KEY `memo_id` (`memo_id`,`user_id`), - CONSTRAINT `memo_organizer_chk_1` CHECK ((`pinned` in (0,1))) -); - --- memo_relation -CREATE TABLE IF NOT EXISTS `memo_relation` ( - `memo_id` int NOT NULL, - `related_memo_id` int NOT NULL, - `type` varchar(256) NOT NULL, - UNIQUE KEY `memo_id` (`memo_id`,`related_memo_id`,`type`) -); +-- drop all tables first +DROP TABLE IF EXISTS `migration_history`; +DROP TABLE IF EXISTS `system_setting`; +DROP TABLE IF EXISTS `user`; +DROP TABLE IF EXISTS `user_setting`; +DROP TABLE IF EXISTS `memo`; +DROP TABLE IF EXISTS `memo_organizer`; +DROP TABLE IF EXISTS `memo_relation`; +DROP TABLE IF EXISTS `resource`; +DROP TABLE IF EXISTS `tag`; +DROP TABLE IF EXISTS `activity`; +DROP TABLE IF EXISTS `storage`; +DROP TABLE IF EXISTS `idp`; -- migration_history -CREATE TABLE IF NOT EXISTS `migration_history` ( +CREATE TABLE `migration_history` ( `version` varchar(255) NOT NULL, `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (`version`) ); --- resource -CREATE TABLE IF NOT EXISTS `resource` ( - `id` int NOT NULL AUTO_INCREMENT, - `creator_id` int NOT NULL, - `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - `filename` text NOT NULL, - `blob` blob, - `external_link` text NOT NULL, - `type` varchar(255) NOT NULL DEFAULT '', - `size` int NOT NULL DEFAULT '0', - `internal_path` varchar(255) NOT NULL DEFAULT '', - `memo_id` int DEFAULT NULL, - PRIMARY KEY (`id`), - KEY `creator_id` (`creator_id`), - KEY `memo_id` (`memo_id`) -); - --- storage -CREATE TABLE IF NOT EXISTS `storage` ( - `id` int NOT NULL AUTO_INCREMENT, - `name` varchar(256) NOT NULL, - `type` varchar(256) NOT NULL, - `config` text NOT NULL, - PRIMARY KEY (`id`) -); - -- system_setting -CREATE TABLE IF NOT EXISTS `system_setting` ( +CREATE TABLE `system_setting` ( `name` varchar(255) NOT NULL, `value` text NOT NULL, `description` text NOT NULL, PRIMARY KEY (`name`) ); --- tag -CREATE TABLE IF NOT EXISTS `tag` ( - `name` varchar(255) NOT NULL, - `creator_id` int NOT NULL, - UNIQUE KEY `name` (`name`,`creator_id`) -); - -- user -CREATE TABLE IF NOT EXISTS `user` ( +CREATE TABLE `user` ( `id` int NOT NULL AUTO_INCREMENT, `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -121,11 +46,98 @@ CREATE TABLE IF NOT EXISTS `user` ( ); -- user_setting -CREATE TABLE IF NOT EXISTS `user_setting` ( +CREATE TABLE `user_setting` ( `user_id` int NOT NULL, `key` varchar(255) NOT NULL, `value` text NOT NULL, UNIQUE KEY `user_id` (`user_id`,`key`) ); +-- memo +CREATE TABLE `memo` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', + `content` text NOT NULL, + `visibility` varchar(255) NOT NULL DEFAULT 'PRIVATE', + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `visibility` (`visibility`), + CONSTRAINT `memo_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), + CONSTRAINT `memo_chk_2` CHECK ((`visibility` in (_utf8mb4'PUBLIC',_utf8mb4'PROTECTED',_utf8mb4'PRIVATE'))) +); +-- memo_organizer +CREATE TABLE `memo_organizer` ( + `memo_id` int NOT NULL, + `user_id` int NOT NULL, + `pinned` int NOT NULL DEFAULT '0', + UNIQUE KEY `memo_id` (`memo_id`,`user_id`), + CONSTRAINT `memo_organizer_chk_1` CHECK ((`pinned` in (0,1))) +); + +-- memo_relation +CREATE TABLE `memo_relation` ( + `memo_id` int NOT NULL, + `related_memo_id` int NOT NULL, + `type` varchar(256) NOT NULL, + UNIQUE KEY `memo_id` (`memo_id`,`related_memo_id`,`type`) +); + +-- resource +CREATE TABLE `resource` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `filename` text NOT NULL, + `blob` blob, + `external_link` text NOT NULL, + `type` varchar(255) NOT NULL DEFAULT '', + `size` int NOT NULL DEFAULT '0', + `internal_path` varchar(255) NOT NULL DEFAULT '', + `memo_id` int DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `memo_id` (`memo_id`) +); + +-- tag +CREATE TABLE `tag` ( + `name` varchar(255) NOT NULL, + `creator_id` int NOT NULL, + UNIQUE KEY `name` (`name`,`creator_id`) +); + +-- activity +CREATE TABLE `activity` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `type` varchar(255) NOT NULL DEFAULT '', + `level` varchar(255) NOT NULL DEFAULT 'INFO', + `payload` text NOT NULL, + PRIMARY KEY (`id`), + CONSTRAINT `activity_chk_1` CHECK ((`level` in (_utf8mb4'INFO',_utf8mb4'WARN',_utf8mb4'ERROR'))) +); + +-- storage +CREATE TABLE `storage` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` varchar(256) NOT NULL, + `type` varchar(256) NOT NULL, + `config` text NOT NULL, + PRIMARY KEY (`id`) +); + +-- idp +CREATE TABLE `idp` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` text NOT NULL, + `type` text NOT NULL, + `identifier_filter` varchar(256) NOT NULL DEFAULT '', + `config` text NOT NULL, + PRIMARY KEY (`id`) +); diff --git a/store/mysql/migration/prod/LATEST__SCHEMA.sql b/store/mysql/migration/prod/LATEST__SCHEMA.sql new file mode 100644 index 000000000..90a6b691a --- /dev/null +++ b/store/mysql/migration/prod/LATEST__SCHEMA.sql @@ -0,0 +1,143 @@ +-- drop all tables first +DROP TABLE IF EXISTS `migration_history`; +DROP TABLE IF EXISTS `system_setting`; +DROP TABLE IF EXISTS `user`; +DROP TABLE IF EXISTS `user_setting`; +DROP TABLE IF EXISTS `memo`; +DROP TABLE IF EXISTS `memo_organizer`; +DROP TABLE IF EXISTS `memo_relation`; +DROP TABLE IF EXISTS `resource`; +DROP TABLE IF EXISTS `tag`; +DROP TABLE IF EXISTS `activity`; +DROP TABLE IF EXISTS `storage`; +DROP TABLE IF EXISTS `idp`; + +-- migration_history +CREATE TABLE `migration_history` ( + `version` varchar(255) NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`version`) +); + +-- system_setting +CREATE TABLE `system_setting` ( + `name` varchar(255) NOT NULL, + `value` text NOT NULL, + `description` text NOT NULL, + PRIMARY KEY (`name`) +); + +-- user +CREATE TABLE `user` ( + `id` int NOT NULL AUTO_INCREMENT, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', + `username` varchar(255) NOT NULL, + `role` varchar(255) NOT NULL DEFAULT 'USER', + `email` varchar(255) NOT NULL DEFAULT '', + `nickname` varchar(255) NOT NULL DEFAULT '', + `password_hash` varchar(255) NOT NULL, + `avatar_url` text NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `username` (`username`), + CONSTRAINT `user_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), + CONSTRAINT `user_chk_2` CHECK ((`role` in (_utf8mb4'HOST',_utf8mb4'ADMIN',_utf8mb4'USER'))) +); + +-- user_setting +CREATE TABLE `user_setting` ( + `user_id` int NOT NULL, + `key` varchar(255) NOT NULL, + `value` text NOT NULL, + UNIQUE KEY `user_id` (`user_id`,`key`) +); + +-- memo +CREATE TABLE `memo` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', + `content` text NOT NULL, + `visibility` varchar(255) NOT NULL DEFAULT 'PRIVATE', + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `visibility` (`visibility`), + CONSTRAINT `memo_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), + CONSTRAINT `memo_chk_2` CHECK ((`visibility` in (_utf8mb4'PUBLIC',_utf8mb4'PROTECTED',_utf8mb4'PRIVATE'))) +); + +-- memo_organizer +CREATE TABLE `memo_organizer` ( + `memo_id` int NOT NULL, + `user_id` int NOT NULL, + `pinned` int NOT NULL DEFAULT '0', + UNIQUE KEY `memo_id` (`memo_id`,`user_id`), + CONSTRAINT `memo_organizer_chk_1` CHECK ((`pinned` in (0,1))) +); + +-- memo_relation +CREATE TABLE `memo_relation` ( + `memo_id` int NOT NULL, + `related_memo_id` int NOT NULL, + `type` varchar(256) NOT NULL, + UNIQUE KEY `memo_id` (`memo_id`,`related_memo_id`,`type`) +); + +-- resource +CREATE TABLE `resource` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `filename` text NOT NULL, + `blob` blob, + `external_link` text NOT NULL, + `type` varchar(255) NOT NULL DEFAULT '', + `size` int NOT NULL DEFAULT '0', + `internal_path` varchar(255) NOT NULL DEFAULT '', + `memo_id` int DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `memo_id` (`memo_id`) +); + +-- tag +CREATE TABLE `tag` ( + `name` varchar(255) NOT NULL, + `creator_id` int NOT NULL, + UNIQUE KEY `name` (`name`,`creator_id`) +); + +-- activity +CREATE TABLE `activity` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `type` varchar(255) NOT NULL DEFAULT '', + `level` varchar(255) NOT NULL DEFAULT 'INFO', + `payload` text NOT NULL, + PRIMARY KEY (`id`), + CONSTRAINT `activity_chk_1` CHECK ((`level` in (_utf8mb4'INFO',_utf8mb4'WARN',_utf8mb4'ERROR'))) +); + +-- storage +CREATE TABLE `storage` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` varchar(256) NOT NULL, + `type` varchar(256) NOT NULL, + `config` text NOT NULL, + PRIMARY KEY (`id`) +); + +-- idp +CREATE TABLE `idp` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` text NOT NULL, + `type` text NOT NULL, + `identifier_filter` varchar(256) NOT NULL DEFAULT '', + `config` text NOT NULL, + PRIMARY KEY (`id`) +); diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index f989bec74..55aaf0893 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -18,13 +18,17 @@ type Driver struct { func NewDriver(profile *profile.Profile) (store.Driver, error) { db, err := sql.Open("mysql", profile.DSN) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "failed to open db: %s", profile.DSN) } driver := Driver{db: db, profile: profile} return &driver, nil } +func (d *Driver) GetDB() *sql.DB { + return d.db +} + func (d *Driver) Vacuum(ctx context.Context) error { tx, err := d.db.BeginTx(ctx, nil) if err != nil { diff --git a/store/sqlite/migration/dev/LATEST__SCHEMA.sql b/store/sqlite/migration/dev/LATEST__SCHEMA.sql index 029bc40e9..9b43e16a0 100644 --- a/store/sqlite/migration/dev/LATEST__SCHEMA.sql +++ b/store/sqlite/migration/dev/LATEST__SCHEMA.sql @@ -1,3 +1,17 @@ +-- drop all tables first +DROP TABLE IF EXISTS migration_history; +DROP TABLE IF EXISTS system_setting; +DROP TABLE IF EXISTS user; +DROP TABLE IF EXISTS user_setting; +DROP TABLE IF EXISTS memo; +DROP TABLE IF EXISTS memo_organizer; +DROP TABLE IF EXISTS memo_relation; +DROP TABLE IF EXISTS resource; +DROP TABLE IF EXISTS tag; +DROP TABLE IF EXISTS activity; +DROP TABLE IF EXISTS storage; +DROP TABLE IF EXISTS idp; + -- migration_history CREATE TABLE migration_history ( version TEXT NOT NULL PRIMARY KEY, @@ -59,6 +73,14 @@ CREATE TABLE memo_organizer ( UNIQUE(memo_id, user_id) ); +-- memo_relation +CREATE TABLE memo_relation ( + memo_id INTEGER NOT NULL, + related_memo_id INTEGER NOT NULL, + type TEXT NOT NULL, + UNIQUE(memo_id, related_memo_id, type) +); + -- resource CREATE TABLE resource ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -111,11 +133,3 @@ CREATE TABLE idp ( identifier_filter TEXT NOT NULL DEFAULT '', config TEXT NOT NULL DEFAULT '{}' ); - --- memo_relation -CREATE TABLE memo_relation ( - memo_id INTEGER NOT NULL, - related_memo_id INTEGER NOT NULL, - type TEXT NOT NULL, - UNIQUE(memo_id, related_memo_id, type) -); diff --git a/store/sqlite/sqlite.go b/store/sqlite/sqlite.go index dc13d0818..279eb14c9 100644 --- a/store/sqlite/sqlite.go +++ b/store/sqlite/sqlite.go @@ -49,6 +49,10 @@ func NewDriver(profile *profile.Profile) (store.Driver, error) { return &driver, nil } +func (d *Driver) GetDB() *sql.DB { + return d.db +} + func (d *Driver) Vacuum(ctx context.Context) error { tx, err := d.db.BeginTx(ctx, nil) if err != nil { diff --git a/test/store/README.md b/test/store/README.md new file mode 100644 index 000000000..988c63177 --- /dev/null +++ b/test/store/README.md @@ -0,0 +1,13 @@ +# Store tests + +## How to test store with MySQL? + +1. Create a database in your MySQL server. +2. Run the following command with two environment variables set: + +```go +DRIVER=mysql DSN=root@/memos_test go test -v ./test/store/... +``` + +- `DRIVER` should be set to `mysql`. +- `DSN` should be set to the DSN of your MySQL server. diff --git a/test/store/store.go b/test/store/store.go index 2ebb4a567..e6e40cdd1 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -5,17 +5,29 @@ import ( "fmt" "testing" - "github.com/usememos/memos/store" - "github.com/usememos/memos/store/sqlite" - "github.com/usememos/memos/test" - + // mysql driver. + _ "github.com/go-sql-driver/mysql" // sqlite driver. _ "modernc.org/sqlite" + + "github.com/usememos/memos/store" + "github.com/usememos/memos/store/mysql" + "github.com/usememos/memos/store/sqlite" + "github.com/usememos/memos/test" ) func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { profile := test.GetTestingProfile(t) - driver, err := sqlite.NewDriver(profile) + var driver store.Driver + var err error + switch profile.Driver { + case "sqlite": + driver, err = sqlite.NewDriver(profile) + case "mysql": + driver, err = mysql.NewDriver(profile) + default: + panic(fmt.Sprintf("unknown db driver: %s", profile.Driver)) + } if err != nil { fmt.Printf("failed to create db driver, error: %+v\n", err) } diff --git a/test/store/user_setting_test.go b/test/store/user_setting_test.go index 59eb192bf..c9341def8 100644 --- a/test/store/user_setting_test.go +++ b/test/store/user_setting_test.go @@ -14,13 +14,13 @@ func TestUserSettingStore(t *testing.T) { ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) require.NoError(t, err) - testSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{ + _, err = ts.UpsertUserSetting(ctx, &store.UserSetting{ UserID: user.ID, Key: "test_key", Value: "test_value", }) require.NoError(t, err) - localeSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{ + _, err = ts.UpsertUserSetting(ctx, &store.UserSetting{ UserID: user.ID, Key: "locale", Value: "zh", @@ -29,6 +29,4 @@ func TestUserSettingStore(t *testing.T) { list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{}) require.NoError(t, err) require.Equal(t, 2, len(list)) - require.Equal(t, testSetting, list[0]) - require.Equal(t, localeSetting, list[1]) } diff --git a/test/test.go b/test/test.go index 64afe7862..10c37e9eb 100644 --- a/test/test.go +++ b/test/test.go @@ -3,6 +3,7 @@ package test import ( "fmt" "net" + "os" "testing" "github.com/usememos/memos/server/profile" @@ -27,11 +28,26 @@ func GetTestingProfile(t *testing.T) *profile.Profile { dir := t.TempDir() mode := "dev" port := getUnusedPort() + driver := getDriverFromEnv() + dsn := os.Getenv("DSN") + if driver == "sqlite" { + dsn = fmt.Sprintf("%s/memos_%s.db", dir, mode) + } + println("dsn", dsn, driver) return &profile.Profile{ Mode: mode, Port: port, Data: dir, - DSN: fmt.Sprintf("%s/memos_%s.db", dir, mode), + DSN: dsn, + Driver: driver, Version: version.GetCurrentVersion(mode), } } + +func getDriverFromEnv() string { + driver := os.Getenv("DRIVER") + if driver == "" { + driver = "sqlite" + } + return driver +}