feat: add setup cmd (#1418)

This command can be used for automatization of initial application's setup
This commit is contained in:
Dmitry Shemin 2023-03-27 20:22:49 +07:00 committed by GitHub
parent 6b703c4678
commit e7ee181a91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 321 additions and 0 deletions

View file

@ -7,12 +7,16 @@ import (
"os"
"os/signal"
"syscall"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/usememos/memos/server"
_profile "github.com/usememos/memos/server/profile"
"github.com/usememos/memos/setup"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
)
const (
@ -69,6 +73,40 @@ var (
<-ctx.Done()
},
}
setupCmd = &cobra.Command{
Use: "setup",
Short: "Make initial setup for memos",
Run: func(cmd *cobra.Command, _ []string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
hostUsername, err := cmd.Flags().GetString(setupCmdFlagHostUsername)
if err != nil {
fmt.Printf("failed to get owner username, error: %+v\n", err)
return
}
hostPassword, err := cmd.Flags().GetString(setupCmdFlagHostPassword)
if err != nil {
fmt.Printf("failed to get owner password, error: %+v\n", err)
return
}
db := db.NewDB(profile)
if err := db.Open(ctx); err != nil {
fmt.Printf("failed to open db, error: %+v\n", err)
return
}
st := store.New(db.DBInstance, profile)
if err := setup.Execute(ctx, st, hostUsername, hostPassword); err != nil {
fmt.Printf("failed to setup, error: %+v\n", err)
return
}
},
}
)
func Execute() error {
@ -98,6 +136,11 @@ func init() {
viper.SetDefault("mode", "demo")
viper.SetDefault("port", 8081)
viper.SetEnvPrefix("memos")
setupCmd.Flags().String(setupCmdFlagHostUsername, "", "Owner username")
setupCmd.Flags().String(setupCmdFlagHostPassword, "", "Owner password")
rootCmd.AddCommand(setupCmd)
}
func initConfig() {
@ -117,3 +160,8 @@ func initConfig() {
println("version:", profile.Version)
println("---")
}
const (
setupCmdFlagHostUsername = "host-username"
setupCmdFlagHostPassword = "host-password"
)

1
go.mod
View file

@ -62,6 +62,7 @@ require (
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.1 // indirect

1
go.sum
View file

@ -244,6 +244,7 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=

90
setup/setup.go Normal file
View file

@ -0,0 +1,90 @@
package setup
import (
"context"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
)
func Execute(
ctx context.Context,
store store,
hostUsername, hostPassword string,
) error {
s := setupService{store: store}
return s.Setup(ctx, hostUsername, hostPassword)
}
type store interface {
FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error)
CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error)
}
type setupService struct {
store store
}
func (s setupService) Setup(
ctx context.Context,
hostUsername, hostPassword string,
) error {
if err := s.makeSureHostUserNotExists(ctx); err != nil {
return err
}
if err := s.createUser(ctx, hostUsername, hostPassword); err != nil {
return fmt.Errorf("create user: %w", err)
}
return nil
}
func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
hostUserType := api.Host
existedHostUsers, err := s.store.FindUserList(ctx, &api.UserFind{
Role: &hostUserType,
})
if err != nil {
return fmt.Errorf("find user list: %w", err)
}
if len(existedHostUsers) != 0 {
return errors.New("host user already exists")
}
return nil
}
func (s setupService) createUser(
ctx context.Context,
hostUsername, hostPassword string,
) error {
userCreate := &api.UserCreate{
Username: hostUsername,
// The new signup user should be normal user by default.
Role: api.Host,
Nickname: hostUsername,
Password: hostPassword,
OpenID: common.GenUUID(),
}
if err := userCreate.Validate(); err != nil {
return fmt.Errorf("validate: %w", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
userCreate.PasswordHash = string(passwordHash)
if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
return fmt.Errorf("create user: %w", err)
}
return nil
}

181
setup/setup_test.go Normal file
View file

@ -0,0 +1,181 @@
package setup
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/usememos/memos/api"
)
func TestSetupService_makeSureHostUserNotExists(t *testing.T) {
cc := map[string]struct {
setupStore func(*storeMock)
expectedErr string
}{
"failed to get list": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return(nil, errors.New("fake error"))
},
expectedErr: "find user list: fake error",
},
"success, not empty": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return([]*api.User{
{},
}, nil)
},
expectedErr: "host user already exists",
},
"success, empty": {
setupStore: func(m *storeMock) {
hostUserType := api.Host
m.
On("FindUserList", mock.Anything, &api.UserFind{
Role: &hostUserType,
}).
Return(nil, nil)
},
},
}
for n, c := range cc {
c := c
t.Run(n, func(t *testing.T) {
sm := newStoreMock(t)
if c.setupStore != nil {
c.setupStore(sm)
}
srv := setupService{store: sm}
err := srv.makeSureHostUserNotExists(context.Background())
if c.expectedErr == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, c.expectedErr)
}
})
}
}
func TestSetupService_createUser(t *testing.T) {
expectedCreated := &api.UserCreate{
Username: "demohero",
Role: api.Host,
Nickname: "demohero",
Password: "123456",
}
userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool {
return arg.Username == expectedCreated.Username &&
arg.Role == expectedCreated.Role &&
arg.Nickname == expectedCreated.Nickname &&
arg.Password == expectedCreated.Password &&
arg.PasswordHash != ""
})
cc := map[string]struct {
setupStore func(*storeMock)
hostUsername, hostPassword string
expectedErr string
}{
`username == "", password == ""`: {
expectedErr: "validate: username is too short, minimum length is 3",
},
`username == "", password != ""`: {
hostPassword: expectedCreated.Password,
expectedErr: "validate: username is too short, minimum length is 3",
},
`username != "", password == ""`: {
hostUsername: expectedCreated.Username,
expectedErr: "validate: password is too short, minimum length is 6",
},
"failed to create": {
setupStore: func(m *storeMock) {
m.
On("CreateUser", mock.Anything, userCreateMatcher).
Return(nil, errors.New("fake error"))
},
hostUsername: expectedCreated.Username,
hostPassword: expectedCreated.Password,
expectedErr: "create user: fake error",
},
"success": {
setupStore: func(m *storeMock) {
m.
On("CreateUser", mock.Anything, userCreateMatcher).
Return(nil, nil)
},
hostUsername: expectedCreated.Username,
hostPassword: expectedCreated.Password,
},
}
for n, c := range cc {
c := c
t.Run(n, func(t *testing.T) {
sm := newStoreMock(t)
if c.setupStore != nil {
c.setupStore(sm)
}
srv := setupService{store: sm}
err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword)
if c.expectedErr == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, c.expectedErr)
}
})
}
}
type storeMock struct {
mock.Mock
}
func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
ret := m.Called(ctx, find)
var uu []*api.User
ret1 := ret.Get(0)
if ret1 != nil {
uu = ret1.([]*api.User)
}
return uu, ret.Error(1)
}
func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
ret := m.Called(ctx, create)
var u *api.User
ret1 := ret.Get(0)
if ret1 != nil {
u = ret1.(*api.User)
}
return u, ret.Error(1)
}
func newStoreMock(t *testing.T) *storeMock {
m := &storeMock{}
m.Mock.Test(t)
t.Cleanup(func() { m.AssertExpectations(t) })
return m
}