diff --git a/cmd/memos.go b/cmd/memos.go index 81ab5ab49..a621a9f07 100644 --- a/cmd/memos.go +++ b/cmd/memos.go @@ -99,9 +99,8 @@ var ( return } - st := store.New(db.DBInstance, profile) - - if err := setup.Execute(ctx, st, hostUsername, hostPassword); err != nil { + store := store.New(db.DBInstance, profile) + if err := setup.Execute(ctx, store, hostUsername, hostPassword); err != nil { fmt.Printf("failed to setup, error: %+v\n", err) return } diff --git a/setup/setup_test.go b/setup/setup_test.go index 15e615143..fd6af3838 100644 --- a/setup/setup_test.go +++ b/setup/setup_test.go @@ -11,7 +11,7 @@ import ( "github.com/usememos/memos/api" ) -func TestSetupService_makeSureHostUserNotExists(t *testing.T) { +func TestSetupServiceMakeSureHostUserNotExists(t *testing.T) { cc := map[string]struct { setupStore func(*storeMock) expectedErr string @@ -71,7 +71,7 @@ func TestSetupService_makeSureHostUserNotExists(t *testing.T) { } } -func TestSetupService_createUser(t *testing.T) { +func TestSetupServiceCreateUser(t *testing.T) { expectedCreated := &api.UserCreate{ Username: "demohero", Role: api.Host, @@ -150,13 +150,13 @@ type storeMock struct { func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) { ret := m.Called(ctx, find) - var uu []*api.User + var u []*api.User ret1 := ret.Get(0) if ret1 != nil { - uu = ret1.([]*api.User) + u = ret1.([]*api.User) } - return uu, ret.Error(1) + return u, ret.Error(1) } func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) { diff --git a/test/store/store.go b/test/store/store.go new file mode 100644 index 000000000..17a96500d --- /dev/null +++ b/test/store/store.go @@ -0,0 +1,25 @@ +package store_test + +import ( + "context" + "fmt" + "testing" + + "github.com/usememos/memos/store" + "github.com/usememos/memos/store/db" + "github.com/usememos/memos/test" + + // sqlite3 driver. + _ "github.com/mattn/go-sqlite3" +) + +func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { + profile := test.GetTestingProfile(t) + db := db.NewDB(profile) + if err := db.Open(ctx); err != nil { + fmt.Printf("failed to open db, error: %+v\n", err) + } + + store := store.New(db.DBInstance, profile) + return store +} diff --git a/test/store/user_test.go b/test/store/user_test.go new file mode 100644 index 000000000..4dd74baea --- /dev/null +++ b/test/store/user_test.go @@ -0,0 +1,47 @@ +package store_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" + "golang.org/x/crypto/bcrypt" +) + +func TestUserStore(t *testing.T) { + ctx := context.Background() + store := NewTestingStore(ctx, t) + userCreate := &api.UserCreate{ + Username: "test", + Role: api.Host, + Email: "test@test.com", + Nickname: "test_nickname", + Password: "test_password", + OpenID: "test_open_id", + } + passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost) + require.NoError(t, err) + userCreate.PasswordHash = string(passwordHash) + user, err := store.CreateUser(ctx, userCreate) + require.NoError(t, err) + users, err := store.FindUserList(ctx, &api.UserFind{}) + require.NoError(t, err) + require.Equal(t, 1, len(users)) + require.Equal(t, user, users[0]) + userPatchNickname := "test_nickname_2" + userPatch := &api.UserPatch{ + ID: user.ID, + Nickname: &userPatchNickname, + } + user, err = store.PatchUser(ctx, userPatch) + require.NoError(t, err) + require.Equal(t, userPatchNickname, user.Nickname) + err = store.DeleteUser(ctx, &api.UserDelete{ + ID: user.ID, + }) + require.NoError(t, err) + users, err = store.FindUserList(ctx, &api.UserFind{}) + require.NoError(t, err) + require.Equal(t, 0, len(users)) +} diff --git a/test/test.go b/test/test.go new file mode 100644 index 000000000..384be8d26 --- /dev/null +++ b/test/test.go @@ -0,0 +1,22 @@ +package test + +import ( + "fmt" + "testing" + + "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/server/version" +) + +func GetTestingProfile(t *testing.T) *profile.Profile { + // Get a temporary directory for the test data. + dir := t.TempDir() + mode := "prod" + return &profile.Profile{ + Mode: mode, + Port: 8082, + Data: dir, + DSN: fmt.Sprintf("%s/memos_%s.db", dir, mode), + Version: version.GetCurrentVersion(mode), + } +}