chore: use tx for stores

This commit is contained in:
boojack 2022-08-07 10:17:12 +08:00
parent 8c28721839
commit d8e10ba399
9 changed files with 279 additions and 197 deletions

View file

@ -15,6 +15,7 @@ import (
func (s *Server) registerMemoRoutes(g *echo.Group) {
g.POST("/memo", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -31,7 +32,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoCreate.Visibility = &private
}
memo, err := s.Store.CreateMemo(memoCreate)
memo, err := s.Store.CreateMemo(ctx, memoCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err)
}
@ -44,6 +45,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.PATCH("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -56,7 +58,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err)
}
memo, err := s.Store.PatchMemo(memoPatch)
memo, err := s.Store.PatchMemo(ctx, memoPatch)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err)
}
@ -69,6 +71,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.GET("/memo", func(c echo.Context) error {
ctx := c.Request().Context()
memoFind := &api.MemoFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
@ -118,7 +121,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind.Offset = offset
}
list, err := s.Store.FindMemoList(memoFind)
list, err := s.Store.FindMemoList(ctx, memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err)
}
@ -131,6 +134,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.POST("/memo/:memoId/organizer", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -148,12 +152,12 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err)
}
err = s.Store.UpsertMemoOrganizer(memoOrganizerUpsert)
err = s.Store.UpsertMemoOrganizer(ctx, memoOrganizerUpsert)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err)
}
memo, err := s.Store.FindMemo(&api.MemoFind{
memo, err := s.Store.FindMemo(ctx, &api.MemoFind{
ID: &memoID,
})
if err != nil {
@ -172,6 +176,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.GET("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -180,7 +185,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind := &api.MemoFind{
ID: &memoID,
}
memo, err := s.Store.FindMemo(memoFind)
memo, err := s.Store.FindMemo(ctx, memoFind)
if err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err)
@ -197,6 +202,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.DELETE("/memo/:memoId", func(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := strconv.Atoi(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
@ -205,7 +211,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoDelete := &api.MemoDelete{
ID: memoID,
}
if err := s.Store.DeleteMemo(memoDelete); err != nil {
if err := s.Store.DeleteMemo(ctx, memoDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err)
}
@ -213,6 +219,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
})
g.GET("/memo/amount", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -223,7 +230,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
RowStatus: &normalRowStatus,
}
memoList, err := s.Store.FindMemoList(memoFind)
memoList, err := s.Store.FindMemoList(ctx, memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
}

View file

@ -14,6 +14,7 @@ import (
func (s *Server) registerResourceRoutes(g *echo.Group) {
g.POST("/resource", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -51,7 +52,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
CreatorID: userID,
}
resource, err := s.Store.CreateResource(resourceCreate)
resource, err := s.Store.CreateResource(ctx, resourceCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
}
@ -64,6 +65,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
})
g.GET("/resource", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -71,7 +73,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
resourceFind := &api.ResourceFind{
CreatorID: &userID,
}
list, err := s.Store.FindResourceList(resourceFind)
list, err := s.Store.FindResourceList(ctx, resourceFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err)
}
@ -84,6 +86,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
})
g.GET("/resource/:resourceId", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -97,7 +100,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: &resourceID,
CreatorID: &userID,
}
resource, err := s.Store.FindResource(resourceFind)
resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err)
}
@ -110,6 +113,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
})
g.GET("/resource/:resourceId/blob", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -123,7 +127,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: &resourceID,
CreatorID: &userID,
}
resource, err := s.Store.FindResource(resourceFind)
resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err)
}
@ -138,6 +142,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
})
g.DELETE("/resource/:resourceId", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -152,7 +157,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
ID: resourceID,
CreatorID: userID,
}
if err := s.Store.DeleteResource(resourceDelete); err != nil {
if err := s.Store.DeleteResource(ctx, resourceDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err)
}

View file

@ -13,6 +13,7 @@ import (
func (s *Server) registerShortcutRoutes(g *echo.Group) {
g.POST("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
@ -24,7 +25,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err)
}
shortcut, err := s.Store.CreateShortcut(shortcutCreate)
shortcut, err := s.Store.CreateShortcut(ctx, shortcutCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err)
}
@ -37,6 +38,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
})
g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -49,7 +51,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err)
}
shortcut, err := s.Store.PatchShortcut(shortcutPatch)
shortcut, err := s.Store.PatchShortcut(ctx, shortcutPatch)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err)
}
@ -62,6 +64,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
})
g.GET("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutFind := &api.ShortcutFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
@ -75,7 +78,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutFind.CreatorID = &userID
}
list, err := s.Store.FindShortcutList(shortcutFind)
list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err)
}
@ -88,6 +91,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
})
g.GET("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -96,7 +100,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutFind := &api.ShortcutFind{
ID: &shortcutID,
}
shortcut, err := s.Store.FindShortcut(shortcutFind)
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err)
}
@ -109,6 +113,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
})
g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error {
ctx := c.Request().Context()
shortcutID, err := strconv.Atoi(c.Param("shortcutId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
@ -117,7 +122,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
shortcutDelete := &api.ShortcutDelete{
ID: shortcutID,
}
if err := s.Store.DeleteShortcut(shortcutDelete); err != nil {
if err := s.Store.DeleteShortcut(ctx, shortcutDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err)
}

View file

@ -14,6 +14,7 @@ import (
func (s *Server) registerTagRoutes(g *echo.Group) {
g.GET("/tag", func(c echo.Context) error {
ctx := c.Request().Context()
contentSearch := "#"
normalRowStatus := api.Normal
memoFind := api.MemoFind{
@ -39,7 +40,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
}
}
memoList, err := s.Store.FindMemoList(&memoFind)
memoList, err := s.Store.FindMemoList(ctx, &memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
}

View file

@ -16,6 +16,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) {
})
g.GET("/r/:resourceId/:filename", func(c echo.Context) error {
ctx := c.Request().Context()
resourceID, err := strconv.Atoi(c.Param("resourceId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
@ -26,7 +27,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) {
ID: &resourceID,
Filename: &filename,
}
resource, err := s.Store.FindResource(resourceFind)
resource, err := s.Store.FindResource(ctx, resourceFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch resource ID: %v", resourceID)).SetInternal(err)
}

View file

@ -1,6 +1,7 @@
package store
import (
"context"
"database/sql"
"fmt"
"strings"
@ -43,13 +44,39 @@ func (raw *memoRaw) toMemo() *api.Memo {
}
}
func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) {
memoRaw, err := createMemoRaw(s.db, create)
func (s *Store) composeMemo(ctx context.Context, raw *memoRaw) (*api.Memo, error) {
memo := raw.toMemo()
memoOrganizer, err := s.FindMemoOrganizer(ctx, &api.MemoOrganizerFind{
MemoID: memo.ID,
UserID: memo.CreatorID,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return nil, err
} else if memoOrganizer != nil {
memo.Pinned = memoOrganizer.Pinned
}
return memo, nil
}
func (s *Store) CreateMemo(ctx context.Context, create *api.MemoCreate) (*api.Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRaw, err := createMemoRaw(ctx, tx, create)
if err != nil {
return nil, err
}
memo, err := s.composeMemo(memoRaw)
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
memo, err := s.composeMemo(ctx, memoRaw)
if err != nil {
return nil, err
}
@ -61,13 +88,23 @@ func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) {
return memo, nil
}
func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) {
memoRaw, err := patchMemoRaw(s.db, patch)
func (s *Store) PatchMemo(ctx context.Context, patch *api.MemoPatch) (*api.Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRaw, err := patchMemoRaw(ctx, tx, patch)
if err != nil {
return nil, err
}
memo, err := s.composeMemo(memoRaw)
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
memo, err := s.composeMemo(ctx, memoRaw)
if err != nil {
return nil, err
}
@ -79,15 +116,21 @@ func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) {
return memo, nil
}
func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) {
memoRawList, err := findMemoRawList(s.db, find)
func (s *Store) FindMemoList(ctx context.Context, find *api.MemoFind) ([]*api.Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoRawList, err := findMemoRawList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.Memo{}
for _, raw := range memoRawList {
memo, err := s.composeMemo(raw)
memo, err := s.composeMemo(ctx, raw)
if err != nil {
return nil, err
}
@ -98,7 +141,7 @@ func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) {
return list, nil
}
func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
func (s *Store) FindMemo(ctx context.Context, find *api.MemoFind) (*api.Memo, error) {
if find.ID != nil {
memo := &api.Memo{}
has, err := s.cache.FindCache(api.MemoCache, *find.ID, memo)
@ -110,7 +153,13 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
}
}
list, err := findMemoRawList(s.db, find)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findMemoRawList(ctx, tx, find)
if err != nil {
return nil, err
}
@ -119,7 +168,7 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
memo, err := s.composeMemo(list[0])
memo, err := s.composeMemo(ctx, list[0])
if err != nil {
return nil, err
}
@ -131,18 +180,27 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) {
return memo, nil
}
func (s *Store) DeleteMemo(delete *api.MemoDelete) error {
err := deleteMemo(s.db, delete)
func (s *Store) DeleteMemo(ctx context.Context, delete *api.MemoDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := deleteMemo(ctx, tx, delete); err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.MemoCache, delete.ID)
return nil
}
func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
func createMemoRaw(ctx context.Context, tx *sql.Tx, create *api.MemoCreate) (*memoRaw, error) {
set := []string{"creator_id", "content"}
placeholder := []string{"?", "?"}
args := []interface{}{create.CreatorID, create.Content}
@ -155,22 +213,14 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
}
query := `
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility`
row, err := db.Query(query,
args...,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility
`
var memoRaw memoRaw
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&memoRaw.ID,
&memoRaw.CreatorID,
&memoRaw.CreatedTs,
@ -185,7 +235,7 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
return &memoRaw, nil
}
func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
func patchMemoRaw(ctx context.Context, tx *sql.Tx, patch *api.MemoPatch) (*memoRaw, error) {
set, args := []string{}, []interface{}{}
if v := patch.Content; v != nil {
@ -200,21 +250,14 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
args = append(args, patch.ID)
row, err := db.Query(`
query := `
UPDATE memo
SET `+strings.Join(set, ", ")+`
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility
`, args...)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
`
var memoRaw memoRaw
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&memoRaw.ID,
&memoRaw.CreatorID,
&memoRaw.CreatedTs,
@ -229,7 +272,7 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
return &memoRaw, nil
}
func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
func findMemoRawList(ctx context.Context, tx *sql.Tx, find *api.MemoFind) ([]*memoRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -264,7 +307,7 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
}
}
rows, err := db.Query(`
query := `
SELECT
id,
creator_id,
@ -274,10 +317,10 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
content,
visibility
FROM memo
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`+pagination,
args...,
)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC
` + pagination
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
@ -308,8 +351,8 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
return memoRawList, nil
}
func deleteMemo(db *sql.DB, delete *api.MemoDelete) error {
result, err := db.Exec(`
func deleteMemo(ctx context.Context, tx *sql.Tx, delete *api.MemoDelete) error {
_, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON;
DELETE FROM memo WHERE id = ?
`, delete.ID)
@ -317,26 +360,5 @@ func deleteMemo(db *sql.DB, delete *api.MemoDelete) error {
return FormatError(err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo ID not found: %d", delete.ID)}
}
return nil
}
func (s *Store) composeMemo(raw *memoRaw) (*api.Memo, error) {
memo := raw.toMemo()
memoOrganizer, err := s.FindMemoOrganizer(&api.MemoOrganizerFind{
MemoID: memo.ID,
UserID: memo.CreatorID,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return nil, err
} else if memoOrganizer != nil {
memo.Pinned = memoOrganizer.Pinned
}
return memo, nil
}

View file

@ -1,6 +1,7 @@
package store
import (
"context"
"database/sql"
"fmt"
@ -29,8 +30,14 @@ func (raw *memoOrganizerRaw) toMemoOrganizer() *api.MemoOrganizer {
}
}
func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) {
memoOrganizerRaw, err := findMemoOrganizer(s.db, find)
func (s *Store) FindMemoOrganizer(ctx context.Context, find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
memoOrganizerRaw, err := findMemoOrganizer(ctx, tx, find)
if err != nil {
return nil, err
}
@ -40,17 +47,26 @@ func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganiz
return memoOrganizer, nil
}
func (s *Store) UpsertMemoOrganizer(upsert *api.MemoOrganizerUpsert) error {
err := upsertMemoOrganizer(s.db, upsert)
func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *api.MemoOrganizerUpsert) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := upsertMemoOrganizer(ctx, tx, upsert); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
return nil
}
func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) {
row, err := db.Query(`
func findMemoOrganizer(ctx context.Context, tx *sql.Tx, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) {
query := `
SELECT
id,
memo_id,
@ -58,7 +74,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR
pinned
FROM memo_organizer
WHERE memo_id = ? AND user_id = ?
`, find.MemoID, find.UserID)
`
row, err := tx.QueryContext(ctx, query, find.MemoID, find.UserID)
if err != nil {
return nil, FormatError(err)
}
@ -81,8 +98,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR
return &memoOrganizerRaw, nil
}
func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error {
row, err := db.Query(`
func upsertMemoOrganizer(ctx context.Context, tx *sql.Tx, upsert *api.MemoOrganizerUpsert) error {
query := `
INSERT INTO memo_organizer (
memo_id,
user_id,
@ -93,20 +110,9 @@ func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error {
SET
pinned = EXCLUDED.pinned
RETURNING id, memo_id, user_id, pinned
`,
upsert.MemoID,
upsert.UserID,
upsert.Pinned,
)
if err != nil {
return FormatError(err)
}
defer row.Close()
row.Next()
`
var memoOrganizer api.MemoOrganizer
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned).Scan(
&memoOrganizer.ID,
&memoOrganizer.MemoID,
&memoOrganizer.UserID,

View file

@ -1,6 +1,7 @@
package store
import (
"context"
"database/sql"
"fmt"
"strings"
@ -43,12 +44,22 @@ func (raw *resourceRaw) toResource() *api.Resource {
}
}
func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error) {
resourceRaw, err := createResource(s.db, create)
func (s *Store) CreateResource(ctx context.Context, create *api.ResourceCreate) (*api.Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
resourceRaw, err := createResource(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
resource := resourceRaw.toResource()
if err := s.cache.UpsertCache(api.ResourceCache, resource.ID, resource); err != nil {
@ -58,8 +69,14 @@ func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error
return resource, nil
}
func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error) {
resourceRawList, err := findResourceList(s.db, find)
func (s *Store) FindResourceList(ctx context.Context, find *api.ResourceFind) ([]*api.Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
resourceRawList, err := findResourceList(ctx, tx, find)
if err != nil {
return nil, err
}
@ -72,7 +89,7 @@ func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error
return resourceList, nil
}
func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) {
func (s *Store) FindResource(ctx context.Context, find *api.ResourceFind) (*api.Resource, error) {
if find.ID != nil {
resource := &api.Resource{}
has, err := s.cache.FindCache(api.ResourceCache, *find.ID, resource)
@ -84,7 +101,13 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) {
}
}
list, err := findResourceList(s.db, find)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findResourceList(ctx, tx, find)
if err != nil {
return nil, err
}
@ -102,19 +125,29 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) {
return resource, nil
}
func (s *Store) DeleteResource(delete *api.ResourceDelete) error {
err := deleteResource(s.db, delete)
func (s *Store) DeleteResource(ctx context.Context, delete *api.ResourceDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
err = deleteResource(ctx, tx, delete)
if err != nil {
return err
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.ResourceCache, delete.ID)
return nil
}
func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error) {
row, err := db.Query(`
func createResource(ctx context.Context, tx *sql.Tx, create *api.ResourceCreate) (*resourceRaw, error) {
query := `
INSERT INTO resource (
filename,
blob,
@ -124,21 +157,9 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error
)
VALUES (?, ?, ?, ?, ?)
RETURNING id, filename, blob, type, size, creator_id, created_ts, updated_ts
`,
create.Filename,
create.Blob,
create.Type,
create.Size,
create.CreatorID,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
`
var resourceRaw resourceRaw
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID).Scan(
&resourceRaw.ID,
&resourceRaw.Filename,
&resourceRaw.Blob,
@ -154,7 +175,7 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error
return &resourceRaw, nil
}
func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error) {
func findResourceList(ctx context.Context, tx *sql.Tx, find *api.ResourceFind) ([]*resourceRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -167,7 +188,7 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
where, args = append(where, "filename = ?"), append(args, *v)
}
rows, err := db.Query(`
query := `
SELECT
id,
filename,
@ -178,10 +199,10 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
created_ts,
updated_ts
FROM resource
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
@ -213,8 +234,8 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error
return resourceRawList, nil
}
func deleteResource(db *sql.DB, delete *api.ResourceDelete) error {
result, err := db.Exec(`
func deleteResource(ctx context.Context, tx *sql.Tx, delete *api.ResourceDelete) error {
_, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON;
DELETE FROM resource WHERE id = ? AND creator_id = ?
`, delete.ID, delete.CreatorID)
@ -222,10 +243,5 @@ func deleteResource(db *sql.DB, delete *api.ResourceDelete) error {
return FormatError(err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("resource ID not found: %d", delete.ID)}
}
return nil
}

View file

@ -1,6 +1,7 @@
package store
import (
"context"
"database/sql"
"fmt"
"strings"
@ -39,12 +40,22 @@ func (raw *shortcutRaw) toShortcut() *api.Shortcut {
}
}
func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error) {
shortcutRaw, err := createShortcut(s.db, create)
func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRaw, err := createShortcut(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
shortcut := shortcutRaw.toShortcut()
if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil {
@ -54,12 +65,22 @@ func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error
return shortcut, nil
}
func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) {
shortcutRaw, err := patchShortcut(s.db, patch)
func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRaw, err := patchShortcut(ctx, tx, patch)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
shortcut := shortcutRaw.toShortcut()
if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil {
@ -69,8 +90,14 @@ func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) {
return shortcut, nil
}
func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error) {
shortcutRawList, err := findShortcutList(s.db, find)
func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRawList, err := findShortcutList(ctx, tx, find)
if err != nil {
return nil, err
}
@ -83,7 +110,7 @@ func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error
return list, nil
}
func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) {
func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) {
if find.ID != nil {
shortcut := &api.Shortcut{}
has, err := s.cache.FindCache(api.ShortcutCache, *find.ID, shortcut)
@ -95,7 +122,13 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) {
}
}
list, err := findShortcutList(s.db, find)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findShortcutList(ctx, tx, find)
if err != nil {
return nil, err
}
@ -113,19 +146,29 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) {
return shortcut, nil
}
func (s *Store) DeleteShortcut(delete *api.ShortcutDelete) error {
err := deleteShortcut(s.db, delete)
func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
err = deleteShortcut(ctx, tx, delete)
if err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.cache.DeleteCache(api.ShortcutCache, delete.ID)
return nil
}
func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error) {
row, err := db.Query(`
func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) {
query := `
INSERT INTO shortcut (
title,
payload,
@ -133,19 +176,9 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error
)
VALUES (?, ?, ?)
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
`,
create.Title,
create.Payload,
create.CreatorID,
)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
row.Next()
`
var shortcutRaw shortcutRaw
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
&shortcutRaw.ID,
&shortcutRaw.Title,
&shortcutRaw.Payload,
@ -160,7 +193,7 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error
return &shortcutRaw, nil
}
func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) {
set, args := []string{}, []interface{}{}
if v := patch.Title; v != nil {
@ -175,23 +208,14 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
args = append(args, patch.ID)
row, err := db.Query(`
query := `
UPDATE shortcut
SET `+strings.Join(set, ", ")+`
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, title, payload, created_ts, updated_ts, row_status
`, args...)
if err != nil {
return nil, FormatError(err)
}
defer row.Close()
if !row.Next() {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
`
var shortcutRaw shortcutRaw
if err := row.Scan(
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcutRaw.ID,
&shortcutRaw.Title,
&shortcutRaw.Payload,
@ -205,7 +229,7 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
return &shortcutRaw, nil
}
func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error) {
func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -218,7 +242,7 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error
where, args = append(where, "title = ?"), append(args, *v)
}
rows, err := db.Query(`
rows, err := tx.QueryContext(ctx, `
SELECT
id,
title,
@ -262,8 +286,8 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error
return shortcutRawList, nil
}
func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error {
result, err := db.Exec(`
func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error {
_, err := tx.ExecContext(ctx, `
PRAGMA foreign_keys = ON;
DELETE FROM shortcut WHERE id = ?
`, delete.ID)
@ -271,10 +295,5 @@ func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error {
return FormatError(err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut ID not found: %d", delete.ID)}
}
return nil
}