mirror of
				https://github.com/usememos/memos.git
				synced 2025-10-26 22:36:16 +08:00 
			
		
		
		
	chore: add tests for migrator
This commit is contained in:
		
							parent
							
								
									96b9269cd3
								
							
						
					
					
						commit
						525223c261
					
				
					 8 changed files with 147 additions and 108 deletions
				
			
		|  | @ -29,11 +29,6 @@ func GetMinorVersion(version string) string { | |||
| 	return versionList[0] + "." + versionList[1] | ||||
| } | ||||
| 
 | ||||
| func GetSchemaVersion(version string) string { | ||||
| 	minorVersion := GetMinorVersion(version) | ||||
| 	return minorVersion + ".0" | ||||
| } | ||||
| 
 | ||||
| // IsVersionGreaterOrEqualThan returns true if version is greater than or equal to target. | ||||
| func IsVersionGreaterOrEqualThan(version, target string) bool { | ||||
| 	return semver.Compare(fmt.Sprintf("v%s", version), fmt.Sprintf("v%s", target)) > -1 | ||||
|  |  | |||
|  | @ -39,10 +39,6 @@ func NewDB(profile *profile.Profile) (store.Driver, error) { | |||
| 	return &driver, nil | ||||
| } | ||||
| 
 | ||||
| func (*DB) Type() string { | ||||
| 	return "mysql" | ||||
| } | ||||
| 
 | ||||
| func (d *DB) GetDB() *sql.DB { | ||||
| 	return d.db | ||||
| } | ||||
|  |  | |||
|  | @ -39,10 +39,6 @@ func NewDB(profile *profile.Profile) (store.Driver, error) { | |||
| 	return driver, nil | ||||
| } | ||||
| 
 | ||||
| func (*DB) Type() string { | ||||
| 	return "postgres" | ||||
| } | ||||
| 
 | ||||
| func (d *DB) GetDB() *sql.DB { | ||||
| 	return d.db | ||||
| } | ||||
|  |  | |||
|  | @ -50,10 +50,6 @@ func NewDB(profile *profile.Profile) (store.Driver, error) { | |||
| 	return &driver, nil | ||||
| } | ||||
| 
 | ||||
| func (*DB) Type() string { | ||||
| 	return "sqlite" | ||||
| } | ||||
| 
 | ||||
| func (d *DB) GetDB() *sql.DB { | ||||
| 	return d.db | ||||
| } | ||||
|  |  | |||
|  | @ -11,10 +11,6 @@ type Driver interface { | |||
| 	GetDB() *sql.DB | ||||
| 	Close() error | ||||
| 
 | ||||
| 	// Type returns the type of the driver. | ||||
| 	// Supported types are: sqlite, mysql, postgres. | ||||
| 	Type() string | ||||
| 
 | ||||
| 	// MigrationHistory model related methods. | ||||
| 	FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error) | ||||
| 	UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) | ||||
|  |  | |||
|  | @ -2,18 +2,30 @@ package store | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"embed" | ||||
| 	"fmt" | ||||
| 	"io/fs" | ||||
| 	"log/slog" | ||||
| 	"regexp" | ||||
| 	"path/filepath" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 
 | ||||
| 	"github.com/usememos/memos/server/version" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// MIGRATE_FILE_NAME_SPLIT is the split character between the patch version and the description in the migration file name. | ||||
| 	// For example, "1__create_table.sql". | ||||
| 	MIGRATE_FILE_NAME_SPLIT = "__" | ||||
| 	// LATEST_SCHEMA_FILE_NAME is the name of the latest schema file. | ||||
| 	// This file is used to apply the latest schema when no migration history is found. | ||||
| 	LATEST_SCHEMA_FILE_NAME = "LATEST__SCHEMA.sql" | ||||
| ) | ||||
| 
 | ||||
| //go:embed migration | ||||
| var migrationFS embed.FS | ||||
| 
 | ||||
|  | @ -41,21 +53,54 @@ func (s *Store) Migrate(ctx context.Context) error { | |||
| 		} | ||||
| 		sort.Sort(version.SortVersion(migrationHistoryVersions)) | ||||
| 		latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1] | ||||
| 		currentVersion := version.GetCurrentVersion(s.Profile.Mode) | ||||
| 		schemaVersion, err := s.GetCurrentSchemaVersion() | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "failed to get current schema version") | ||||
| 		} | ||||
| 
 | ||||
| 		if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) { | ||||
| 			filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s/*/*.sql", s.getMigrationBasePath())) | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "failed to read migration files") | ||||
| 			} | ||||
| 			sort.Strings(filePaths) | ||||
| 
 | ||||
| 			// Start a transaction to apply the latest schema. | ||||
| 			tx, err := s.driver.GetDB().Begin() | ||||
| 			if err != nil { | ||||
| 				return errors.Wrap(err, "failed to start transaction") | ||||
| 			} | ||||
| 			defer tx.Rollback() | ||||
| 
 | ||||
| 		if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { | ||||
| 			minorVersionList := s.getMinorVersionList() | ||||
| 			fmt.Println("start migration") | ||||
| 			for _, minorVersion := range minorVersionList { | ||||
| 				normalizedVersion := minorVersion + ".0" | ||||
| 				if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { | ||||
| 					fmt.Println("applying migration for", normalizedVersion) | ||||
| 					if err := s.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { | ||||
| 						return errors.Wrap(err, "failed to apply minor version migration") | ||||
| 			for _, filePath := range filePaths { | ||||
| 				fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) | ||||
| 				if err != nil { | ||||
| 					return errors.Wrap(err, "failed to get schema version of migrate script") | ||||
| 				} | ||||
| 				if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) { | ||||
| 					bytes, err := migrationFS.ReadFile(filePath) | ||||
| 					if err != nil { | ||||
| 						return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath) | ||||
| 					} | ||||
| 					stmt := string(bytes) | ||||
| 					if err := s.execute(ctx, tx, stmt); err != nil { | ||||
| 						return errors.Wrapf(err, "migrate error: %s", stmt) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if err := tx.Commit(); err != nil { | ||||
| 				return errors.Wrap(err, "failed to commit transaction") | ||||
| 			} | ||||
| 			fmt.Println("end migrate") | ||||
| 
 | ||||
| 			// Upsert the current schema version to migration_history. | ||||
| 			if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ | ||||
| 				Version: schemaVersion, | ||||
| 			}); err != nil { | ||||
| 				return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion) | ||||
| 			} | ||||
| 		} | ||||
| 	} else if s.Profile.Mode == "demo" { | ||||
| 		// In demo mode, we should seed the database. | ||||
|  | @ -68,21 +113,36 @@ func (s *Store) Migrate(ctx context.Context) error { | |||
| 
 | ||||
| func (s *Store) preMigrate(ctx context.Context) error { | ||||
| 	migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) | ||||
| 	// If there is no migration history, we should apply the latest schema. | ||||
| 	// If any error occurs or no migration history found, apply the latest schema. | ||||
| 	if err != nil || len(migrationHistoryList) == 0 { | ||||
| 		if err != nil { | ||||
| 			slog.Error("failed to find migration history", "error", err) | ||||
| 			slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error())) | ||||
| 		} | ||||
| 		fileName := s.getMigrationBasePath() + latestSchemaFileName | ||||
| 		bytes, err := migrationFS.ReadFile(fileName) | ||||
| 		filePath := s.getMigrationBasePath() + LATEST_SCHEMA_FILE_NAME | ||||
| 		bytes, err := migrationFS.ReadFile(filePath) | ||||
| 		if err != nil { | ||||
| 			return errors.Errorf("failed to read latest schema file: %s", err) | ||||
| 		} | ||||
| 		if err := s.execute(ctx, string(bytes)); err != nil { | ||||
| 			return errors.Errorf("failed to exec SQL file %s, err %s", fileName, err) | ||||
| 		schemaVersion, err := s.GetCurrentSchemaVersion() | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "failed to get current schema version") | ||||
| 		} | ||||
| 
 | ||||
| 		// Start a transaction to apply the latest schema. | ||||
| 		tx, err := s.driver.GetDB().Begin() | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "failed to start transaction") | ||||
| 		} | ||||
| 		defer tx.Rollback() | ||||
| 		if err := s.execute(ctx, tx, string(bytes)); err != nil { | ||||
| 			return errors.Errorf("failed to execute SQL file %s, err %s", filePath, err) | ||||
| 		} | ||||
| 		if err := tx.Commit(); err != nil { | ||||
| 			return errors.Wrap(err, "failed to commit transaction") | ||||
| 		} | ||||
| 
 | ||||
| 		if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ | ||||
| 			Version: version.GetCurrentVersion(s.Profile.Mode), | ||||
| 			Version: schemaVersion, | ||||
| 		}); err != nil { | ||||
| 			return errors.Wrap(err, "failed to upsert migration history") | ||||
| 		} | ||||
|  | @ -95,52 +155,17 @@ func (s *Store) getMigrationBasePath() string { | |||
| 	if s.Profile.Mode == "prod" { | ||||
| 		mode = "prod" | ||||
| 	} | ||||
| 	return fmt.Sprintf("migration/%s/%s/", s.driver.Type(), mode) | ||||
| 	return fmt.Sprintf("migration/%s/%s/", s.Profile.Driver, mode) | ||||
| } | ||||
| 
 | ||||
| func (s *Store) getSeedBasePath() string { | ||||
| 	return fmt.Sprintf("seed/%s/", s.driver.Type()) | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	latestSchemaFileName = "LATEST__SCHEMA.sql" | ||||
| ) | ||||
| 
 | ||||
| func (s *Store) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { | ||||
| 	filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion)) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "failed to read migration files") | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Strings(filenames) | ||||
| 	migrationStmt := "" | ||||
| 	// Loop over all migration files and execute them in order. | ||||
| 	for _, filename := range filenames { | ||||
| 		buf, err := migrationFS.ReadFile(filename) | ||||
| 		if err != nil { | ||||
| 			return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) | ||||
| 		} | ||||
| 		stmt := string(buf) | ||||
| 		migrationStmt += stmt | ||||
| 		if err := s.execute(ctx, stmt); err != nil { | ||||
| 			return errors.Wrapf(err, "migrate error: %s", stmt) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Upsert the newest version to migration_history. | ||||
| 	version := minorVersion + ".0" | ||||
| 	if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ | ||||
| 		Version: version, | ||||
| 	}); err != nil { | ||||
| 		return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| 	return fmt.Sprintf("seed/%s/", s.Profile.Driver) | ||||
| } | ||||
| 
 | ||||
| func (s *Store) seed(ctx context.Context) error { | ||||
| 	// Only seed for SQLite. | ||||
| 	if s.driver.Type() != "sqlite" { | ||||
| 	if s.Profile.Driver != "sqlite" { | ||||
| 		slog.Warn("seed is only supported for SQLite") | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
|  | @ -149,49 +174,67 @@ func (s *Store) seed(ctx context.Context) error { | |||
| 		return errors.Wrap(err, "failed to read seed files") | ||||
| 	} | ||||
| 
 | ||||
| 	// Sort seed files by name. This is important to ensure that seed files are applied in order. | ||||
| 	sort.Strings(filenames) | ||||
| 	// Start a transaction to apply the seed files. | ||||
| 	tx, err := s.driver.GetDB().Begin() | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "failed to start transaction") | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 	// Loop over all seed files and execute them in order. | ||||
| 	for _, filename := range filenames { | ||||
| 		bytes, err := seedFS.ReadFile(filename) | ||||
| 		if err != nil { | ||||
| 			return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) | ||||
| 		} | ||||
| 		if err := s.execute(ctx, string(bytes)); err != nil { | ||||
| 		if err := s.execute(ctx, tx, string(bytes)); err != nil { | ||||
| 			return errors.Wrapf(err, "seed error: %s", filename) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // execute runs a single SQL statement within a transaction. | ||||
| func (s *Store) execute(ctx context.Context, stmt string) error { | ||||
| 	tx, err := s.driver.GetDB().Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 	if _, err := tx.ExecContext(ctx, stmt); err != nil { | ||||
| 		return errors.Wrap(err, "failed to execute statement") | ||||
| 	} | ||||
| 	return tx.Commit() | ||||
| } | ||||
| 
 | ||||
| func (s *Store) getMinorVersionList() []string { | ||||
| 	var minorDirRegexp = regexp.MustCompile(fmt.Sprintf(`^%s[0-9]+\.[0-9]+$`, s.getMigrationBasePath())) | ||||
| 	minorVersionList := []string{} | ||||
| 	if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if file.IsDir() && minorDirRegexp.MatchString(path) { | ||||
| 			minorVersionList = append(minorVersionList, file.Name()) | ||||
| 		} | ||||
| 
 | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		panic(err) | ||||
| func (s *Store) GetCurrentSchemaVersion() (string, error) { | ||||
| 	currentVersion := version.GetCurrentVersion(s.Profile.Mode) | ||||
| 	minorVersion := version.GetMinorVersion(currentVersion) | ||||
| 	filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion)) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrap(err, "failed to read migration files") | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Sort(version.SortVersion(minorVersionList)) | ||||
| 	return minorVersionList | ||||
| 	sort.Strings(filePaths) | ||||
| 	if len(filePaths) == 0 { | ||||
| 		return fmt.Sprintf("%s.0", minorVersion), nil | ||||
| 	} else { | ||||
| 		return s.getSchemaVersionOfMigrateScript(filePaths[len(filePaths)-1]) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *Store) getSchemaVersionOfMigrateScript(filePath string) (string, error) { | ||||
| 	// If the file is the latest schema file, return the current schema version. | ||||
| 	if strings.HasSuffix(filePath, LATEST_SCHEMA_FILE_NAME) { | ||||
| 		return s.GetCurrentSchemaVersion() | ||||
| 	} | ||||
| 
 | ||||
| 	normalizedPath := filepath.ToSlash(filePath) | ||||
| 	elements := strings.Split(normalizedPath, "/") | ||||
| 	if len(elements) < 2 { | ||||
| 		return "", errors.Errorf("invalid file path: %s", filePath) | ||||
| 	} | ||||
| 	minorVersion := elements[len(elements)-2] | ||||
| 	rawPatchVersion := strings.Split(elements[len(elements)-1], MIGRATE_FILE_NAME_SPLIT)[0] | ||||
| 	patchVersion, err := strconv.Atoi(rawPatchVersion) | ||||
| 	if err != nil { | ||||
| 		return "", errors.Wrapf(err, "failed to convert patch version to int: %s", rawPatchVersion) | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s.%d", minorVersion, patchVersion+1), nil | ||||
| } | ||||
| 
 | ||||
| // execute runs a single SQL statement within a transaction. | ||||
| func (s *Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error { | ||||
| 	if _, err := tx.ExecContext(ctx, stmt); err != nil { | ||||
| 		return errors.Wrap(err, "failed to execute statement") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
							
								
								
									
										17
									
								
								test/store/migrator_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								test/store/migrator_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | |||
| package teststore | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestGetCurrentSchemaVersion(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	ts := NewTestingStore(ctx, t) | ||||
| 
 | ||||
| 	currentSchemaVersion, err := ts.GetCurrentSchemaVersion() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "0.22.4", currentSchemaVersion) | ||||
| } | ||||
|  | @ -32,7 +32,7 @@ func GetTestingProfile(t *testing.T) *profile.Profile { | |||
| 
 | ||||
| 	// Get a temporary directory for the test data. | ||||
| 	dir := t.TempDir() | ||||
| 	mode := "dev" | ||||
| 	mode := "prod" | ||||
| 	port := getUnusedPort() | ||||
| 	driver := getDriverFromEnv() | ||||
| 	dsn := os.Getenv("DSN") | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue