diff --git a/api/user_setting.go b/api/user_setting.go index ff5b0bb7..fb7490cb 100644 --- a/api/user_setting.go +++ b/api/user_setting.go @@ -3,8 +3,10 @@ package api type UserSettingKey string const ( - // UserSettingLocaleKey is the key type for user locale + // UserSettingLocaleKey is the key type for user locale. UserSettingLocaleKey UserSettingKey = "locale" + // UserSettingMemoVisibilityKey is the key type for user perference memo default visibility. + UserSettingMemoVisibilityKey UserSettingKey = "memo-visibility" ) // String returns the string format of UserSettingKey type. @@ -12,6 +14,8 @@ func (key UserSettingKey) String() string { switch key { case UserSettingLocaleKey: return "locale" + case UserSettingMemoVisibilityKey: + return "memo-visibility" } return "" } @@ -31,4 +35,6 @@ type UserSettingUpsert struct { type UserSettingFind struct { UserID int + + Key *UserSettingKey `json:"key"` } diff --git a/server/memo.go b/server/memo.go index 75671295..367ff7cf 100644 --- a/server/memo.go +++ b/server/memo.go @@ -27,6 +27,18 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo request").SetInternal(err) } + userSettingMemoVisibilityKey := api.UserSettingMemoVisibilityKey + userMemoVisibilitySetting, err := s.Store.FindUserSetting(ctx, &api.UserSettingFind{ + UserID: userID, + Key: &userSettingMemoVisibilityKey, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user setting").SetInternal(err) + } + if userMemoVisibilitySetting != nil { + memoCreate.Visibility = (*api.Visibility)(&userMemoVisibilitySetting.Value) + } + if memoCreate.Visibility == nil || *memoCreate.Visibility == "" { private := api.Privite memoCreate.Visibility = &private diff --git a/store/user_setting.go b/store/user_setting.go index 2f32a1f1..39f69e6d 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -3,6 +3,7 @@ package store import ( "context" "database/sql" + "strings" "github.com/usememos/memos/api" ) @@ -62,6 +63,27 @@ func (s *Store) FindUserSettingList(ctx context.Context, find *api.UserSettingFi return list, nil } +func (s *Store) FindUserSetting(ctx context.Context, find *api.UserSettingFind) (*api.UserSetting, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := findUserSettingList(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + userSetting := list[0].toUserSetting() + + return userSetting, nil +} + func upsertUserSetting(ctx context.Context, tx *sql.Tx, upsert *api.UserSettingUpsert) (*userSettingRaw, error) { query := ` INSERT INTO user_setting ( @@ -86,15 +108,22 @@ func upsertUserSetting(ctx context.Context, tx *sql.Tx, upsert *api.UserSettingU } func findUserSettingList(ctx context.Context, tx *sql.Tx, find *api.UserSettingFind) ([]*userSettingRaw, error) { + where, args := []string{"1 = 1"}, []interface{}{} + + if v := find.Key; v != nil { + where, args = append(where, "key = ?"), append(args, (*v).String()) + } + + where, args = append(where, "user_id = ?"), append(args, find.UserID) + query := ` SELECT user_id, key, value FROM user_setting - WHERE user_id = ? - ` - rows, err := tx.QueryContext(ctx, query, find.UserID) + WHERE ` + strings.Join(where, " AND ") + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, FormatError(err) }