From e7ee181a91e305f047b5da464286215c142d54d7 Mon Sep 17 00:00:00 2001 From: Dmitry Shemin Date: Mon, 27 Mar 2023 20:22:49 +0700 Subject: [PATCH] feat: add setup cmd (#1418) This command can be used for automatization of initial application's setup --- cmd/memos.go | 48 ++++++++++++ go.mod | 1 + go.sum | 1 + setup/setup.go | 90 ++++++++++++++++++++++ setup/setup_test.go | 181 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 321 insertions(+) create mode 100644 setup/setup.go create mode 100644 setup/setup_test.go diff --git a/cmd/memos.go b/cmd/memos.go index 9b97670f..81ab5ab4 100644 --- a/cmd/memos.go +++ b/cmd/memos.go @@ -7,12 +7,16 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/usememos/memos/server" _profile "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/setup" + "github.com/usememos/memos/store" + "github.com/usememos/memos/store/db" ) const ( @@ -69,6 +73,40 @@ var ( <-ctx.Done() }, } + + setupCmd = &cobra.Command{ + Use: "setup", + Short: "Make initial setup for memos", + Run: func(cmd *cobra.Command, _ []string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + hostUsername, err := cmd.Flags().GetString(setupCmdFlagHostUsername) + if err != nil { + fmt.Printf("failed to get owner username, error: %+v\n", err) + return + } + + hostPassword, err := cmd.Flags().GetString(setupCmdFlagHostPassword) + if err != nil { + fmt.Printf("failed to get owner password, error: %+v\n", err) + return + } + + db := db.NewDB(profile) + if err := db.Open(ctx); err != nil { + fmt.Printf("failed to open db, error: %+v\n", err) + return + } + + st := store.New(db.DBInstance, profile) + + if err := setup.Execute(ctx, st, hostUsername, hostPassword); err != nil { + fmt.Printf("failed to setup, error: %+v\n", err) + return + } + }, + } ) func Execute() error { @@ -98,6 +136,11 @@ func init() { viper.SetDefault("mode", "demo") viper.SetDefault("port", 8081) viper.SetEnvPrefix("memos") + + setupCmd.Flags().String(setupCmdFlagHostUsername, "", "Owner username") + setupCmd.Flags().String(setupCmdFlagHostPassword, "", "Owner password") + + rootCmd.AddCommand(setupCmd) } func initConfig() { @@ -117,3 +160,8 @@ func initConfig() { println("version:", profile.Version) println("---") } + +const ( + setupCmdFlagHostUsername = "host-username" + setupCmdFlagHostPassword = "host-password" +) diff --git a/go.mod b/go.mod index 1b3351b9..2df8995d 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect diff --git a/go.sum b/go.sum index 3c6913af..64caa9ed 100644 --- a/go.sum +++ b/go.sum @@ -244,6 +244,7 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/setup/setup.go b/setup/setup.go new file mode 100644 index 00000000..b542983e --- /dev/null +++ b/setup/setup.go @@ -0,0 +1,90 @@ +package setup + +import ( + "context" + "errors" + "fmt" + + "golang.org/x/crypto/bcrypt" + + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" +) + +func Execute( + ctx context.Context, + store store, + hostUsername, hostPassword string, +) error { + s := setupService{store: store} + return s.Setup(ctx, hostUsername, hostPassword) +} + +type store interface { + FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) + CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) +} + +type setupService struct { + store store +} + +func (s setupService) Setup( + ctx context.Context, + hostUsername, hostPassword string, +) error { + if err := s.makeSureHostUserNotExists(ctx); err != nil { + return err + } + + if err := s.createUser(ctx, hostUsername, hostPassword); err != nil { + return fmt.Errorf("create user: %w", err) + } + return nil +} + +func (s setupService) makeSureHostUserNotExists(ctx context.Context) error { + hostUserType := api.Host + existedHostUsers, err := s.store.FindUserList(ctx, &api.UserFind{ + Role: &hostUserType, + }) + if err != nil { + return fmt.Errorf("find user list: %w", err) + } + + if len(existedHostUsers) != 0 { + return errors.New("host user already exists") + } + + return nil +} + +func (s setupService) createUser( + ctx context.Context, + hostUsername, hostPassword string, +) error { + userCreate := &api.UserCreate{ + Username: hostUsername, + // The new signup user should be normal user by default. + Role: api.Host, + Nickname: hostUsername, + Password: hostPassword, + OpenID: common.GenUUID(), + } + + if err := userCreate.Validate(); err != nil { + return fmt.Errorf("validate: %w", err) + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + + userCreate.PasswordHash = string(passwordHash) + if _, err := s.store.CreateUser(ctx, userCreate); err != nil { + return fmt.Errorf("create user: %w", err) + } + + return nil +} diff --git a/setup/setup_test.go b/setup/setup_test.go new file mode 100644 index 00000000..15e61514 --- /dev/null +++ b/setup/setup_test.go @@ -0,0 +1,181 @@ +package setup + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/usememos/memos/api" +) + +func TestSetupService_makeSureHostUserNotExists(t *testing.T) { + cc := map[string]struct { + setupStore func(*storeMock) + expectedErr string + }{ + "failed to get list": { + setupStore: func(m *storeMock) { + hostUserType := api.Host + m. + On("FindUserList", mock.Anything, &api.UserFind{ + Role: &hostUserType, + }). + Return(nil, errors.New("fake error")) + }, + expectedErr: "find user list: fake error", + }, + "success, not empty": { + setupStore: func(m *storeMock) { + hostUserType := api.Host + m. + On("FindUserList", mock.Anything, &api.UserFind{ + Role: &hostUserType, + }). + Return([]*api.User{ + {}, + }, nil) + }, + expectedErr: "host user already exists", + }, + "success, empty": { + setupStore: func(m *storeMock) { + hostUserType := api.Host + m. + On("FindUserList", mock.Anything, &api.UserFind{ + Role: &hostUserType, + }). + Return(nil, nil) + }, + }, + } + + for n, c := range cc { + c := c + t.Run(n, func(t *testing.T) { + sm := newStoreMock(t) + if c.setupStore != nil { + c.setupStore(sm) + } + + srv := setupService{store: sm} + err := srv.makeSureHostUserNotExists(context.Background()) + if c.expectedErr == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, c.expectedErr) + } + }) + } +} + +func TestSetupService_createUser(t *testing.T) { + expectedCreated := &api.UserCreate{ + Username: "demohero", + Role: api.Host, + Nickname: "demohero", + Password: "123456", + } + + userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool { + return arg.Username == expectedCreated.Username && + arg.Role == expectedCreated.Role && + arg.Nickname == expectedCreated.Nickname && + arg.Password == expectedCreated.Password && + arg.PasswordHash != "" + }) + + cc := map[string]struct { + setupStore func(*storeMock) + hostUsername, hostPassword string + expectedErr string + }{ + `username == "", password == ""`: { + expectedErr: "validate: username is too short, minimum length is 3", + }, + `username == "", password != ""`: { + hostPassword: expectedCreated.Password, + expectedErr: "validate: username is too short, minimum length is 3", + }, + `username != "", password == ""`: { + hostUsername: expectedCreated.Username, + expectedErr: "validate: password is too short, minimum length is 6", + }, + "failed to create": { + setupStore: func(m *storeMock) { + m. + On("CreateUser", mock.Anything, userCreateMatcher). + Return(nil, errors.New("fake error")) + }, + hostUsername: expectedCreated.Username, + hostPassword: expectedCreated.Password, + expectedErr: "create user: fake error", + }, + "success": { + setupStore: func(m *storeMock) { + m. + On("CreateUser", mock.Anything, userCreateMatcher). + Return(nil, nil) + }, + hostUsername: expectedCreated.Username, + hostPassword: expectedCreated.Password, + }, + } + + for n, c := range cc { + c := c + t.Run(n, func(t *testing.T) { + sm := newStoreMock(t) + if c.setupStore != nil { + c.setupStore(sm) + } + + srv := setupService{store: sm} + err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword) + if c.expectedErr == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, c.expectedErr) + } + }) + } +} + +type storeMock struct { + mock.Mock +} + +func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) { + ret := m.Called(ctx, find) + + var uu []*api.User + ret1 := ret.Get(0) + if ret1 != nil { + uu = ret1.([]*api.User) + } + + return uu, ret.Error(1) +} + +func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) { + ret := m.Called(ctx, create) + + var u *api.User + ret1 := ret.Get(0) + if ret1 != nil { + u = ret1.(*api.User) + } + + return u, ret.Error(1) +} + +func newStoreMock(t *testing.T) *storeMock { + m := &storeMock{} + m.Mock.Test(t) + + t.Cleanup(func() { m.AssertExpectations(t) }) + + return m +}