package cmd

import (
	"context"
	"fmt"
	"strings"

	"github.com/pkg/errors"
	"github.com/spf13/cobra"

	_profile "github.com/usememos/memos/server/profile"
	"github.com/usememos/memos/store"
	"github.com/usememos/memos/store/db"
)

var (
	// copydbCmdFlagFrom is the name of the flag that carries the source
	// database, written as `<driver>://<dsn>`. copydbCmd is the `copydb`
	// sub-command: it parses that flag into a source profile and copies its
	// data into the database described by the package-level `profile`.
	copydbCmdFlagFrom = "from"
	copydbCmd         = &cobra.Command{
		Use:   "copydb", // `copydb` is short for 'copy database'
		Short: "Copy data between db drivers",
		Run: func(cmd *cobra.Command, _ []string) {
			s, err := cmd.Flags().GetString(copydbCmdFlagFrom)
			if err != nil {
				println("fail to get from driver DSN")
				// The builtin println prints an interface value as raw
				// addresses, so pass the message text explicitly.
				println(err.Error())
				return
			}
			// Split on the first "://" only: the DSN part may legitimately
			// contain "://" itself (e.g. inside a query parameter).
			ss := strings.SplitN(s, "://", 2)
			if len(ss) != 2 {
				println("fail to parse from driver DSN, should be like 'sqlite://memos_prod.db' or 'mysql://user:pass@tcp(host)/memos'")
				return
			}

			fromProfile := &_profile.Profile{Driver: ss[0], DSN: ss[1]}

			err = copydb(fromProfile, profile)
			if err != nil {
				fmt.Printf("fail to copydb: %s\n", err)
				return
			}

			println("done")
		},
	}
)

// init declares the `from` flag on copydbCmd (defaulting to a local sqlite
// file) and attaches the command to rootCmd.
func init() {
	flags := copydbCmd.Flags()
	flags.String(copydbCmdFlagFrom, "sqlite://memos_prod.db", "From driver DSN")

	rootCmd.AddCommand(copydbCmd)
}

// copydb copies the data of every registered table from the database
// described by fromProfile into the database described by toProfile.
//
// The destination is migrated first. The copy is refused unless every
// registered destination table is empty. Tables are then copied one at a time
// in a fixed order. The first error aborts the run and is returned; copydb
// itself does not undo tables that were already copied.
func copydb(fromProfile, toProfile *_profile.Profile) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	toDriver, err := db.NewDBDriver(toProfile)
	if err != nil {
		return errors.Wrap(err, "fail to create `to` driver")
	}

	if err := toDriver.Migrate(ctx); err != nil {
		return errors.Wrap(err, "fail to migrate db")
	}

	fromDriver, err := db.NewDBDriver(fromProfile)
	if err != nil {
		return errors.Wrap(err, "fail to create `from` driver")
	}

	// Register here if any table is added.
	//
	// This is a slice rather than a map because map iteration order is
	// randomized, which made the check and copy order differ between runs.
	// Tables whose rows are referenced by ID from other tables (user, memo)
	// are listed before the tables that reference them.
	copyList := []struct {
		table string
		fn    func(context.Context, store.Driver, store.Driver) error
	}{
		{"user", copyUser},
		{"user_setting", copyUserSettings},
		{"system_setting", copySystemSettings},
		{"idp", copyIdp},
		{"storage", copyStorage},
		{"memo", copyMemo},
		{"memo_organizer", copyMemoOrganizer},
		{"memo_relation", copyMemoRelation},
		{"resource", copyResource},
		{"tag", copyTag},
		{"activity", copyActivity},
	}

	// Refuse to write into a destination that already holds data.
	// Table names come from the hard-coded list above, never from user input,
	// so concatenating them into the query is safe.
	toDb := toDriver.GetDB()
	for _, entry := range copyList {
		println("Checking " + entry.table + "...")
		var cnt int
		err := toDb.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+entry.table).Scan(&cnt)
		if err != nil {
			return errors.Wrapf(err, "fail to check '%s'", entry.table)
		}
		if cnt > 0 {
			return errors.Errorf("table '%s' is not empty", entry.table)
		}
	}

	for _, entry := range copyList {
		if err := entry.fn(ctx, fromDriver, toDriver); err != nil {
			return errors.Wrapf(err, "fail to copy data of table '%s'", entry.table)
		}
	}

	return nil
}

// copyActivity reads the activities that fromDriver.ListActivity returns for
// an empty filter and recreates each one in toDriver via CreateActivity,
// carrying over the ID, creator, creation timestamp, level, type and payload.
// It stops at, and returns, the first error.
func copyActivity(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying Activity...")
	activities, err := fromDriver.ListActivity(ctx, &store.FindActivity{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(activities))

	for i := range activities {
		src := activities[i]
		record := &store.Activity{
			ID:        src.ID,
			CreatorID: src.CreatorID,
			CreatedTs: src.CreatedTs,
			Level:     src.Level,
			Type:      src.Type,
			Payload:   src.Payload,
		}
		if _, err := toDriver.CreateActivity(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyIdp reads the identity providers that fromDriver.ListIdentityProviders
// returns for an empty filter and recreates each one in toDriver via
// CreateIdentityProvider, carrying over the ID, name, type, identifier filter
// and config. It stops at, and returns, the first error.
func copyIdp(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying IdentityProvider...")
	providers, err := fromDriver.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(providers))

	for i := range providers {
		src := providers[i]
		record := &store.IdentityProvider{
			ID:               src.ID,
			Name:             src.Name,
			Type:             src.Type,
			IdentifierFilter: src.IdentifierFilter,
			Config:           src.Config,
		}
		if _, err := toDriver.CreateIdentityProvider(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyMemo reads the memos that fromDriver.ListMemos returns for an empty
// filter and recreates each one in toDriver via CreateMemo, carrying over the
// ID, creator, timestamps, row status, content and visibility. It stops at,
// and returns, the first error.
func copyMemo(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying Memo...")
	memos, err := fromDriver.ListMemos(ctx, &store.FindMemo{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(memos))

	for i := range memos {
		src := memos[i]
		record := &store.Memo{
			ID:         src.ID,
			CreatorID:  src.CreatorID,
			CreatedTs:  src.CreatedTs,
			UpdatedTs:  src.UpdatedTs,
			RowStatus:  src.RowStatus,
			Content:    src.Content,
			Visibility: src.Visibility,
		}
		if _, err := toDriver.CreateMemo(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyMemoOrganizer reads the memo organizers that
// fromDriver.ListMemoOrganizer returns for an empty filter and writes each one
// to toDriver via UpsertMemoOrganizer, carrying over the memo ID, user ID and
// pinned flag. It stops at, and returns, the first error.
func copyMemoOrganizer(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying MemoOrganizer...")
	organizers, err := fromDriver.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(organizers))

	for i := range organizers {
		src := organizers[i]
		record := &store.MemoOrganizer{
			MemoID: src.MemoID,
			UserID: src.UserID,
			Pinned: src.Pinned,
		}
		if _, err := toDriver.UpsertMemoOrganizer(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyMemoRelation reads the memo relations that fromDriver.ListMemoRelations
// returns for an empty filter and writes each one to toDriver via
// UpsertMemoRelation, carrying over the memo ID, related memo ID and type.
// It stops at, and returns, the first error.
func copyMemoRelation(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying MemoRelation...")
	relations, err := fromDriver.ListMemoRelations(ctx, &store.FindMemoRelation{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(relations))

	for i := range relations {
		src := relations[i]
		record := &store.MemoRelation{
			MemoID:        src.MemoID,
			RelatedMemoID: src.RelatedMemoID,
			Type:          src.Type,
		}
		if _, err := toDriver.UpsertMemoRelation(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyResource reads the resources that fromDriver.ListResources returns for
// a filter with GetBlob set, and recreates each one in toDriver via
// CreateResource. The ID, creator, timestamps, filename, blob, external link,
// type, size, internal path and memo ID are carried over. It stops at, and
// returns, the first error.
func copyResource(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying Resource...")
	// GetBlob is requested so the Blob field copied below is filled in.
	resources, err := fromDriver.ListResources(ctx, &store.FindResource{GetBlob: true})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(resources))

	for i := range resources {
		src := resources[i]
		record := &store.Resource{
			ID:           src.ID,
			CreatorID:    src.CreatorID,
			CreatedTs:    src.CreatedTs,
			UpdatedTs:    src.UpdatedTs,
			Filename:     src.Filename,
			Blob:         src.Blob,
			ExternalLink: src.ExternalLink,
			Type:         src.Type,
			Size:         src.Size,
			InternalPath: src.InternalPath,
			MemoID:       src.MemoID,
		}
		if _, err := toDriver.CreateResource(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyStorage reads the storages that fromDriver.ListStorages returns for an
// empty filter and recreates each one in toDriver via CreateStorage, carrying
// over the ID, name, type and config. It stops at, and returns, the first
// error.
func copyStorage(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying Storage...")
	storages, err := fromDriver.ListStorages(ctx, &store.FindStorage{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(storages))

	for i := range storages {
		src := storages[i]
		record := &store.Storage{
			ID:     src.ID,
			Name:   src.Name,
			Type:   src.Type,
			Config: src.Config,
		}
		if _, err := toDriver.CreateStorage(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copySystemSettings reads the system settings that
// fromDriver.ListSystemSettings returns for an empty filter and writes each
// one to toDriver via UpsertSystemSetting, carrying over the name, value and
// description. It stops at, and returns, the first error.
func copySystemSettings(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying SystemSettings...")
	settings, err := fromDriver.ListSystemSettings(ctx, &store.FindSystemSetting{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(settings))

	for i := range settings {
		src := settings[i]
		record := &store.SystemSetting{
			Name:        src.Name,
			Value:       src.Value,
			Description: src.Description,
		}
		if _, err := toDriver.UpsertSystemSetting(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyTag reads the tags that fromDriver.ListTags returns for an empty filter
// and writes each one to toDriver via UpsertTag, carrying over the name and
// creator ID. It stops at, and returns, the first error.
func copyTag(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying Tag...")
	tags, err := fromDriver.ListTags(ctx, &store.FindTag{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(tags))

	for i := range tags {
		src := tags[i]
		record := &store.Tag{
			Name:      src.Name,
			CreatorID: src.CreatorID,
		}
		if _, err := toDriver.UpsertTag(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyUser reads the users that fromDriver.ListUsers returns for an empty
// filter and recreates each one in toDriver via CreateUser. The ID,
// timestamps, row status, username, role, email, nickname, password hash and
// avatar URL are carried over. It stops at, and returns, the first error.
func copyUser(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying User...")
	users, err := fromDriver.ListUsers(ctx, &store.FindUser{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(users))

	for i := range users {
		src := users[i]
		record := &store.User{
			ID:           src.ID,
			CreatedTs:    src.CreatedTs,
			UpdatedTs:    src.UpdatedTs,
			RowStatus:    src.RowStatus,
			Username:     src.Username,
			Role:         src.Role,
			Email:        src.Email,
			Nickname:     src.Nickname,
			PasswordHash: src.PasswordHash,
			AvatarURL:    src.AvatarURL,
		}
		if _, err := toDriver.CreateUser(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}

// copyUserSettings reads the user settings that fromDriver.ListUserSettings
// returns for an empty filter and writes each one to toDriver via
// UpsertUserSetting, carrying over the key, value and user ID. It stops at,
// and returns, the first error.
func copyUserSettings(ctx context.Context, fromDriver, toDriver store.Driver) error {
	println("Copying UserSettings...")
	settings, err := fromDriver.ListUserSettings(ctx, &store.FindUserSetting{})
	if err != nil {
		return err
	}
	fmt.Printf("\tTotal %d records\n", len(settings))

	for i := range settings {
		src := settings[i]
		record := &store.UserSetting{
			Key:    src.Key,
			Value:  src.Value,
			UserID: src.UserID,
		}
		if _, err := toDriver.UpsertUserSetting(ctx, record); err != nil {
			return err
		}
	}

	println("\tDONE")
	return nil
}