diff --git a/store/activity.go b/store/activity.go index e6c630d7..20da01cc 100644 --- a/store/activity.go +++ b/store/activity.go @@ -2,8 +2,30 @@ package store import ( "context" + + storepb "github.com/usememos/memos/proto/gen/store" ) +type ActivityType string + +const ( + ActivityTypeMemoComment ActivityType = "MEMO_COMMENT" +) + +func (t ActivityType) String() string { + return string(t) +} + +type ActivityLevel string + +const ( + ActivityLevelInfo ActivityLevel = "INFO" +) + +func (l ActivityLevel) String() string { + return string(l) +} + type Activity struct { ID int32 @@ -12,9 +34,9 @@ type Activity struct { CreatedTs int64 // Domain specific fields - Type string - Level string - Payload string + Type ActivityType + Level ActivityLevel + Payload *storepb.ActivityPayload } type FindActivity struct { diff --git a/store/db/mysql/activity.go b/store/db/mysql/activity.go index b985a622..075aed0f 100644 --- a/store/db/mysql/activity.go +++ b/store/db/mysql/activity.go @@ -5,14 +5,24 @@ import ( "strings" "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + payloadString := "{}" + if create.Payload != nil { + bytes, err := protojson.Marshal(create.Payload) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal activity payload") + } + payloadString = string(bytes) + } fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"} placeholder := []string{"?", "?", "?", "?"} - args := []any{create.CreatorID, create.Type, create.Level, create.Payload} + args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} if create.ID != 0 { fields = append(fields, "`id`") @@ -64,17 +74,23 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s list := []*store.Activity{} for rows.Next() { activity := &store.Activity{} + var payloadBytes []byte if err := rows.Scan( &activity.ID, &activity.CreatorID, &activity.Type, &activity.Level, - &activity.Payload, + &payloadBytes, &activity.CreatedTs, ); err != nil { return nil, err } + payload := &storepb.ActivityPayload{} + if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil { + return nil, err + } + activity.Payload = payload list = append(list, activity) } diff --git a/store/db/mysql/common.go b/store/db/mysql/common.go new file mode 100644 index 00000000..37ffe296 --- /dev/null +++ b/store/db/mysql/common.go @@ -0,0 +1,9 @@ +package mysql + +import "google.golang.org/protobuf/encoding/protojson" + +var ( + protojsonUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, + } +) diff --git a/store/db/sqlite/activity.go b/store/db/sqlite/activity.go index 59d9897a..1853df1e 100644 --- a/store/db/sqlite/activity.go +++ b/store/db/sqlite/activity.go @@ -4,13 +4,26 @@ import ( "context" "strings" + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + payloadString := "{}" + if create.Payload != nil { + bytes, err := protojson.Marshal(create.Payload) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal activity payload") + } + payloadString = string(bytes) + } + fields := []string{"`creator_id`", "`type`", "`level`", "`payload`"} placeholder := []string{"?", "?", "?", "?"} - args := []any{create.CreatorID, create.Type, create.Level, create.Payload} + args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} if create.ID != 0 { fields = append(fields, "`id`") @@ -52,17 +65,23 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s list := []*store.Activity{} for rows.Next() { activity := &store.Activity{} + var payloadBytes []byte if err := rows.Scan( &activity.ID, &activity.CreatorID, &activity.Type, &activity.Level, - &activity.Payload, + &payloadBytes, &activity.CreatedTs, ); err != nil { return nil, err } + payload := &storepb.ActivityPayload{} + if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil { + return nil, err + } + activity.Payload = payload list = append(list, activity) } diff --git a/test/store/activity_test.go b/test/store/activity_test.go new file mode 100644 index 00000000..d5b996d2 --- /dev/null +++ b/test/store/activity_test.go @@ -0,0 +1,33 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func TestActivityStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + create := &store.Activity{ + CreatorID: user.ID, + Type: store.ActivityTypeMemoComment, + Level: store.ActivityLevelInfo, + Payload: &storepb.ActivityPayload{}, + } + activity, err := ts.CreateActivity(ctx, create) + require.NoError(t, err) + require.NotNil(t, activity) + activities, err := ts.ListActivities(ctx, &store.FindActivity{ + ID: &activity.ID, + }) + require.NoError(t, err) + require.Equal(t, 1, len(activities)) + require.Equal(t, activity, activities[0]) +}