From 5c5199920e5bbd86a57051b3dadfa1d159f7dbff Mon Sep 17 00:00:00 2001 From: boojack Date: Sat, 13 May 2023 22:25:15 +0800 Subject: [PATCH] chore: seed data for new user (#1655) --- store/user.go | 37 +++++++++++++++++++++++++++++++ test/server/memo_relation_test.go | 4 ++-- test/server/memo_test.go | 6 ++--- test/store/memo_test.go | 4 ++-- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/store/user.go b/store/user.go index ac59ef11..a4843e9e 100644 --- a/store/user.go +++ b/store/user.go @@ -10,6 +10,37 @@ import ( "github.com/usememos/memos/common" ) +func (s *Store) SeedDataForNewUser(ctx context.Context, user *api.User) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + // Create a memo for the user. + _, err = createMemoRaw(ctx, tx, &api.MemoCreate{ + CreatorID: user.ID, + Content: "#inbox Welcome to Memos!", + Visibility: api.Private, + }) + if err != nil { + return err + } + _, err = upsertTag(ctx, tx, &api.TagUpsert{ + CreatorID: user.ID, + Name: "inbox", + }) + if err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } + + return nil +} + // userRaw is the store model for an User. // Fields have exactly the same meanings as User. type userRaw struct { @@ -63,6 +94,7 @@ func (s *Store) ComposeMemoCreator(ctx context.Context, memo *api.Memo) error { } return nil } + func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { @@ -81,6 +113,11 @@ func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.Us s.userCache.Store(userRaw.ID, userRaw) user := userRaw.toUser() + + if err := s.SeedDataForNewUser(ctx, user); err != nil { + return nil, err + } + return user, nil } diff --git a/test/server/memo_relation_test.go b/test/server/memo_relation_test.go index 8be8a000..fac2bb0d 100644 --- a/test/server/memo_relation_test.go +++ b/test/server/memo_relation_test.go @@ -27,7 +27,7 @@ func TestMemoRelationServer(t *testing.T) { require.Equal(t, signup.Username, user.Username) memoList, err := s.getMemoList() require.NoError(t, err) - require.Len(t, memoList, 0) + require.Len(t, memoList, 1) memo, err := s.postMemoCreate(&api.MemoCreate{ Content: "test memo", }) @@ -46,7 +46,7 @@ func TestMemoRelationServer(t *testing.T) { require.Equal(t, "test memo2", memo2.Content) memoList, err = s.getMemoList() require.NoError(t, err) - require.Len(t, memoList, 2) + require.Len(t, memoList, 3) require.Len(t, memo2.RelationList, 1) err = s.deleteMemoRelation(memo2.ID, memo.ID, api.MemoRelationReference) require.NoError(t, err) diff --git a/test/server/memo_test.go b/test/server/memo_test.go index 87e55104..bfaf3728 100644 --- a/test/server/memo_test.go +++ b/test/server/memo_test.go @@ -27,7 +27,7 @@ func TestMemoServer(t *testing.T) { require.Equal(t, signup.Username, user.Username) memoList, err := s.getMemoList() require.NoError(t, err) - require.Len(t, memoList, 0) + require.Len(t, memoList, 1) memo, err := s.postMemoCreate(&api.MemoCreate{ Content: "test memo", }) @@ -35,7 +35,7 @@ func TestMemoServer(t *testing.T) { require.Equal(t, "test memo", memo.Content) memoList, err = s.getMemoList() require.NoError(t, err) - require.Len(t, memoList, 1) + require.Len(t, memoList, 2) updatedContent := "updated memo" memo, err = s.patchMemo(&api.MemoPatch{ ID: memo.ID, @@ -63,7 +63,7 @@ func TestMemoServer(t *testing.T) { require.NoError(t, err) memoList, err = s.getMemoList() require.NoError(t, err) - require.Len(t, memoList, 0) + require.Len(t, memoList, 1) } func (s *TestingServer) getMemo(memoID int) (*api.Memo, error) { diff --git a/test/store/memo_test.go b/test/store/memo_test.go index 37abc238..7f465aff 100644 --- a/test/store/memo_test.go +++ b/test/store/memo_test.go @@ -33,8 +33,8 @@ func TestMemoStore(t *testing.T) { CreatorID: &user.ID, }) require.NoError(t, err) - require.Equal(t, 1, len(memoList)) - require.Equal(t, memo, memoList[0]) + require.Equal(t, 2, len(memoList)) + require.Equal(t, memo, memoList[1]) err = store.DeleteMemo(ctx, &api.MemoDelete{ ID: memo.ID, })