diff --git a/cmd/copydb.go b/cmd/copydb.go new file mode 100644 index 00000000..ebebad8f --- /dev/null +++ b/cmd/copydb.go @@ -0,0 +1,383 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + _profile "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/store" + "github.com/usememos/memos/store/db" +) + +var ( + copydbCmdFlagFrom = "from" + copydbCmd = &cobra.Command{ + Use: "copydb", // `copydb` is a shortened for 'copy database' + Short: "Copy data between db drivers", + Run: func(cmd *cobra.Command, _ []string) { + s, err := cmd.Flags().GetString(copydbCmdFlagFrom) + if err != nil { + println("fail to get from driver DSN") + println(err) + return + } + ss := strings.Split(s, "://") + if len(ss) != 2 { + println("fail to parse from driver DSN, should be like 'sqlite://memos_prod.db' or 'mysql://user:pass@tcp(host)/memos'") + return + } + + fromProfile := &_profile.Profile{Driver: ss[0], DSN: ss[1]} + + err = copydb(fromProfile, profile) + if err != nil { + fmt.Printf("fail to copydb: %s\n", err) + return + } + + println("done") + }, + } +) + +func init() { + copydbCmd.Flags().String(copydbCmdFlagFrom, "sqlite://memos_prod.db", "From driver DSN") + + rootCmd.AddCommand(copydbCmd) +} + +func copydb(fromProfile, toProfile *_profile.Profile) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + toDriver, err := db.NewDBDriver(toProfile) + if err != nil { + return errors.Wrap(err, "fail to create `to` driver") + } + + if err := toDriver.Migrate(ctx); err != nil { + return errors.Wrap(err, "fail to migrate db") + } + + fromDriver, err := db.NewDBDriver(fromProfile) + if err != nil { + return errors.Wrap(err, "fail to create `from` driver") + } + + // Register here if any table is added + copyMap := map[string]func(context.Context, store.Driver, store.Driver) error{ + "activity": copyActivity, + "idp": copyIdp, + "memo": copyMemo, + "memo_organizer": copyMemoOrganizer, + "memo_relation": copyMemoRelation, + "resource": copyResource, + "storage": copyStorage, + "system_setting": copySystemSettings, + "tag": copyTag, + "user": copyUser, + "user_setting": copyUserSettings, + } + + toDb := toDriver.GetDB() + for table := range copyMap { + println("Checking " + table + "...") + var cnt int + err := toDb.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table).Scan(&cnt) + if err != nil { + return errors.Wrapf(err, "fail to check '%s'", table) + } + if cnt > 0 { + return errors.Errorf("table '%s' is not empty", table) + } + } + + for _, f := range copyMap { + err = f(ctx, fromDriver, toDriver) + if err != nil { + return errors.Wrap(err, "fail to copy data") + } + } + + return nil +} + +func copyActivity(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying Activity...") + list, err := fromDriver.ListActivity(ctx, &store.FindActivity{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateActivity(ctx, &store.Activity{ + ID: item.ID, + CreatorID: item.CreatorID, + CreatedTs: item.CreatedTs, + Level: item.Level, + Type: item.Type, + Payload: item.Payload, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyIdp(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying IdentityProvider...") + list, err := fromDriver.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateIdentityProvider(ctx, &store.IdentityProvider{ + ID: item.ID, + Name: item.Name, + Type: item.Type, + IdentifierFilter: item.IdentifierFilter, + Config: item.Config, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyMemo(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying Memo...") + list, err := fromDriver.ListMemos(ctx, &store.FindMemo{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateMemo(ctx, &store.Memo{ + ID: item.ID, + CreatorID: item.CreatorID, + CreatedTs: item.CreatedTs, + UpdatedTs: item.UpdatedTs, + RowStatus: item.RowStatus, + Content: item.Content, + Visibility: item.Visibility, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyMemoOrganizer(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying MemoOrganizer...") + list, err := fromDriver.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{ + MemoID: item.MemoID, + UserID: item.UserID, + Pinned: item.Pinned, + }) + if err != nil { + return err + } + } + println("\tDONE") + return nil +} + +func copyMemoRelation(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying MemoRelation...") + list, err := fromDriver.ListMemoRelations(ctx, &store.FindMemoRelation{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: item.MemoID, + RelatedMemoID: item.RelatedMemoID, + Type: item.Type, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyResource(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying Resource...") + list, err := fromDriver.ListResources(ctx, &store.FindResource{GetBlob: true}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateResource(ctx, &store.Resource{ + ID: item.ID, + CreatorID: item.CreatorID, + CreatedTs: item.CreatedTs, + UpdatedTs: item.UpdatedTs, + Filename: item.Filename, + Blob: item.Blob, + ExternalLink: item.ExternalLink, + Type: item.Type, + Size: item.Size, + InternalPath: item.InternalPath, + MemoID: item.MemoID, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyStorage(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying Storage...") + list, err := fromDriver.ListStorages(ctx, &store.FindStorage{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateStorage(ctx, &store.Storage{ + ID: item.ID, + Name: item.Name, + Type: item.Type, + Config: item.Config, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copySystemSettings(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying SystemSettings...") + list, err := fromDriver.ListSystemSettings(ctx, &store.FindSystemSetting{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.UpsertSystemSetting(ctx, &store.SystemSetting{ + Name: item.Name, + Value: item.Value, + Description: item.Description, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyTag(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying Tag...") + list, err := fromDriver.ListTags(ctx, &store.FindTag{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.UpsertTag(ctx, &store.Tag{ + Name: item.Name, + CreatorID: item.CreatorID, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyUser(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying User...") + list, err := fromDriver.ListUsers(ctx, &store.FindUser{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.CreateUser(ctx, &store.User{ + ID: item.ID, + CreatedTs: item.CreatedTs, + UpdatedTs: item.UpdatedTs, + RowStatus: item.RowStatus, + Username: item.Username, + Role: item.Role, + Email: item.Email, + Nickname: item.Nickname, + PasswordHash: item.PasswordHash, + AvatarURL: item.AvatarURL, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +} + +func copyUserSettings(ctx context.Context, fromDriver, toDriver store.Driver) error { + println("Copying UserSettings...") + list, err := fromDriver.ListUserSettings(ctx, &store.FindUserSetting{}) + if err != nil { + return err + } + + fmt.Printf("\tTotal %d records\n", len(list)) + for _, item := range list { + _, err := toDriver.UpsertUserSetting(ctx, &store.UserSetting{ + Key: item.Key, + Value: item.Value, + UserID: item.UserID, + }) + if err != nil { + return err + } + } + + println("\tDONE") + return nil +}