diff --git a/server/memo.go b/server/memo.go index 62b90586..75671295 100644 --- a/server/memo.go +++ b/server/memo.go @@ -15,6 +15,7 @@ import ( func (s *Server) registerMemoRoutes(g *echo.Group) { g.POST("/memo", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -31,7 +32,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoCreate.Visibility = &private } - memo, err := s.Store.CreateMemo(memoCreate) + memo, err := s.Store.CreateMemo(ctx, memoCreate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) } @@ -44,6 +45,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.PATCH("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() memoID, err := strconv.Atoi(c.Param("memoId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) @@ -56,7 +58,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err) } - memo, err := s.Store.PatchMemo(memoPatch) + memo, err := s.Store.PatchMemo(ctx, memoPatch) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err) } @@ -69,6 +71,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.GET("/memo", func(c echo.Context) error { + ctx := c.Request().Context() memoFind := &api.MemoFind{} if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { @@ -118,7 +121,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoFind.Offset = offset } - list, err := s.Store.FindMemoList(memoFind) + list, err := s.Store.FindMemoList(ctx, memoFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err) } @@ -131,6 +134,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.POST("/memo/:memoId/organizer", func(c echo.Context) error { + ctx := c.Request().Context() memoID, err := strconv.Atoi(c.Param("memoId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) @@ -148,12 +152,12 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err) } - err = s.Store.UpsertMemoOrganizer(memoOrganizerUpsert) + err = s.Store.UpsertMemoOrganizer(ctx, memoOrganizerUpsert) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err) } - memo, err := s.Store.FindMemo(&api.MemoFind{ + memo, err := s.Store.FindMemo(ctx, &api.MemoFind{ ID: &memoID, }) if err != nil { @@ -172,6 +176,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.GET("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() memoID, err := strconv.Atoi(c.Param("memoId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) @@ -180,7 +185,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoFind := &api.MemoFind{ ID: &memoID, } - memo, err := s.Store.FindMemo(memoFind) + memo, err := s.Store.FindMemo(ctx, memoFind) if err != nil { if common.ErrorCode(err) == common.NotFound { return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) @@ -197,6 +202,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.DELETE("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() memoID, err := strconv.Atoi(c.Param("memoId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) @@ -205,7 +211,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoDelete := &api.MemoDelete{ ID: memoID, } - if err := s.Store.DeleteMemo(memoDelete); err != nil { + if err := s.Store.DeleteMemo(ctx, memoDelete); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) } @@ -213,6 +219,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { }) g.GET("/memo/amount", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -223,7 +230,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { RowStatus: &normalRowStatus, } - memoList, err := s.Store.FindMemoList(memoFind) + memoList, err := s.Store.FindMemoList(ctx, memoFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) } diff --git a/server/resource.go b/server/resource.go index 0b95d0ce..2cec7880 100644 --- a/server/resource.go +++ b/server/resource.go @@ -14,6 +14,7 @@ import ( func (s *Server) registerResourceRoutes(g *echo.Group) { g.POST("/resource", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -51,7 +52,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { CreatorID: userID, } - resource, err := s.Store.CreateResource(resourceCreate) + resource, err := s.Store.CreateResource(ctx, resourceCreate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) } @@ -64,6 +65,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { }) g.GET("/resource", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -71,7 +73,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { resourceFind := &api.ResourceFind{ CreatorID: &userID, } - list, err := s.Store.FindResourceList(resourceFind) + list, err := s.Store.FindResourceList(ctx, resourceFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) } @@ -84,6 +86,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { }) g.GET("/resource/:resourceId", func(c echo.Context) error { + ctx := c.Request().Context() resourceID, err := strconv.Atoi(c.Param("resourceId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) @@ -97,7 +100,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ID: &resourceID, CreatorID: &userID, } - resource, err := s.Store.FindResource(resourceFind) + resource, err := s.Store.FindResource(ctx, resourceFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) } @@ -110,6 +113,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { }) g.GET("/resource/:resourceId/blob", func(c echo.Context) error { + ctx := c.Request().Context() resourceID, err := strconv.Atoi(c.Param("resourceId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) @@ -123,7 +127,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ID: &resourceID, CreatorID: &userID, } - resource, err := s.Store.FindResource(resourceFind) + resource, err := s.Store.FindResource(ctx, resourceFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) } @@ -138,6 +142,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { }) g.DELETE("/resource/:resourceId", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -152,7 +157,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ID: resourceID, CreatorID: userID, } - if err := s.Store.DeleteResource(resourceDelete); err != nil { + if err := s.Store.DeleteResource(ctx, resourceDelete); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err) } diff --git a/server/shortcut.go b/server/shortcut.go index fab14e89..c966ff6e 100644 --- a/server/shortcut.go +++ b/server/shortcut.go @@ -13,6 +13,7 @@ import ( func (s *Server) registerShortcutRoutes(g *echo.Group) { g.POST("/shortcut", func(c echo.Context) error { + ctx := c.Request().Context() userID, ok := c.Get(getUserIDContextKey()).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") @@ -24,7 +25,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err) } - shortcut, err := s.Store.CreateShortcut(shortcutCreate) + shortcut, err := s.Store.CreateShortcut(ctx, shortcutCreate) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err) } @@ -37,6 +38,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { }) g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error { + ctx := c.Request().Context() shortcutID, err := strconv.Atoi(c.Param("shortcutId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) @@ -49,7 +51,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err) } - shortcut, err := s.Store.PatchShortcut(shortcutPatch) + shortcut, err := s.Store.PatchShortcut(ctx, shortcutPatch) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err) } @@ -62,6 +64,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { }) g.GET("/shortcut", func(c echo.Context) error { + ctx := c.Request().Context() shortcutFind := &api.ShortcutFind{} if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { @@ -75,7 +78,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { shortcutFind.CreatorID = &userID } - list, err := s.Store.FindShortcutList(shortcutFind) + list, err := s.Store.FindShortcutList(ctx, shortcutFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) } @@ -88,6 +91,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { }) g.GET("/shortcut/:shortcutId", func(c echo.Context) error { + ctx := c.Request().Context() shortcutID, err := strconv.Atoi(c.Param("shortcutId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) @@ -96,7 +100,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { shortcutFind := &api.ShortcutFind{ ID: &shortcutID, } - shortcut, err := s.Store.FindShortcut(shortcutFind) + shortcut, err := s.Store.FindShortcut(ctx, shortcutFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err) } @@ -109,6 +113,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { }) g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error { + ctx := c.Request().Context() shortcutID, err := strconv.Atoi(c.Param("shortcutId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) @@ -117,7 +122,7 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { shortcutDelete := &api.ShortcutDelete{ ID: shortcutID, } - if err := s.Store.DeleteShortcut(shortcutDelete); err != nil { + if err := s.Store.DeleteShortcut(ctx, shortcutDelete); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err) } diff --git a/server/tag.go b/server/tag.go index 8c755eef..e2b7227d 100644 --- a/server/tag.go +++ b/server/tag.go @@ -14,6 +14,7 @@ import ( func (s *Server) registerTagRoutes(g *echo.Group) { g.GET("/tag", func(c echo.Context) error { + ctx := c.Request().Context() contentSearch := "#" normalRowStatus := api.Normal memoFind := api.MemoFind{ @@ -39,7 +40,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) { } } - memoList, err := s.Store.FindMemoList(&memoFind) + memoList, err := s.Store.FindMemoList(ctx, &memoFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) } diff --git a/server/webhook.go b/server/webhook.go index da77ddb5..975f7cb9 100644 --- a/server/webhook.go +++ b/server/webhook.go @@ -16,6 +16,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) { }) g.GET("/r/:resourceId/:filename", func(c echo.Context) error { + ctx := c.Request().Context() resourceID, err := strconv.Atoi(c.Param("resourceId")) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) @@ -26,7 +27,7 @@ func (s *Server) registerWebhookRoutes(g *echo.Group) { ID: &resourceID, Filename: &filename, } - resource, err := s.Store.FindResource(resourceFind) + resource, err := s.Store.FindResource(ctx, resourceFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch resource ID: %v", resourceID)).SetInternal(err) } diff --git a/store/memo.go b/store/memo.go index ff1fdf65..25c0c832 100644 --- a/store/memo.go +++ b/store/memo.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" "strings" @@ -43,13 +44,39 @@ func (raw *memoRaw) toMemo() *api.Memo { } } -func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) { - memoRaw, err := createMemoRaw(s.db, create) +func (s *Store) composeMemo(ctx context.Context, raw *memoRaw) (*api.Memo, error) { + memo := raw.toMemo() + + memoOrganizer, err := s.FindMemoOrganizer(ctx, &api.MemoOrganizerFind{ + MemoID: memo.ID, + UserID: memo.CreatorID, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return nil, err + } else if memoOrganizer != nil { + memo.Pinned = memoOrganizer.Pinned + } + + return memo, nil +} + +func (s *Store) CreateMemo(ctx context.Context, create *api.MemoCreate) (*api.Memo, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoRaw, err := createMemoRaw(ctx, tx, create) if err != nil { return nil, err } - memo, err := s.composeMemo(memoRaw) + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + + memo, err := s.composeMemo(ctx, memoRaw) if err != nil { return nil, err } @@ -61,13 +88,23 @@ func (s *Store) CreateMemo(create *api.MemoCreate) (*api.Memo, error) { return memo, nil } -func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) { - memoRaw, err := patchMemoRaw(s.db, patch) +func (s *Store) PatchMemo(ctx context.Context, patch *api.MemoPatch) (*api.Memo, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoRaw, err := patchMemoRaw(ctx, tx, patch) if err != nil { return nil, err } - memo, err := s.composeMemo(memoRaw) + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + + memo, err := s.composeMemo(ctx, memoRaw) if err != nil { return nil, err } @@ -79,15 +116,21 @@ func (s *Store) PatchMemo(patch *api.MemoPatch) (*api.Memo, error) { return memo, nil } -func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) { - memoRawList, err := findMemoRawList(s.db, find) +func (s *Store) FindMemoList(ctx context.Context, find *api.MemoFind) ([]*api.Memo, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoRawList, err := findMemoRawList(ctx, tx, find) if err != nil { return nil, err } list := []*api.Memo{} for _, raw := range memoRawList { - memo, err := s.composeMemo(raw) + memo, err := s.composeMemo(ctx, raw) if err != nil { return nil, err } @@ -98,7 +141,7 @@ func (s *Store) FindMemoList(find *api.MemoFind) ([]*api.Memo, error) { return list, nil } -func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) { +func (s *Store) FindMemo(ctx context.Context, find *api.MemoFind) (*api.Memo, error) { if find.ID != nil { memo := &api.Memo{} has, err := s.cache.FindCache(api.MemoCache, *find.ID, memo) @@ -110,7 +153,13 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) { } } - list, err := findMemoRawList(s.db, find) + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := findMemoRawList(ctx, tx, find) if err != nil { return nil, err } @@ -119,7 +168,7 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) { return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} } - memo, err := s.composeMemo(list[0]) + memo, err := s.composeMemo(ctx, list[0]) if err != nil { return nil, err } @@ -131,18 +180,27 @@ func (s *Store) FindMemo(find *api.MemoFind) (*api.Memo, error) { return memo, nil } -func (s *Store) DeleteMemo(delete *api.MemoDelete) error { - err := deleteMemo(s.db, delete) +func (s *Store) DeleteMemo(ctx context.Context, delete *api.MemoDelete) error { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return FormatError(err) } + defer tx.Rollback() + + if err := deleteMemo(ctx, tx, delete); err != nil { + return FormatError(err) + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } s.cache.DeleteCache(api.MemoCache, delete.ID) return nil } -func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) { +func createMemoRaw(ctx context.Context, tx *sql.Tx, create *api.MemoCreate) (*memoRaw, error) { set := []string{"creator_id", "content"} placeholder := []string{"?", "?"} args := []interface{}{create.CreatorID, create.Content} @@ -155,22 +213,14 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) { } query := ` - INSERT INTO memo ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) - RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility` - row, err := db.Query(query, - args..., - ) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - row.Next() + INSERT INTO memo ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility + ` var memoRaw memoRaw - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, args...).Scan( &memoRaw.ID, &memoRaw.CreatorID, &memoRaw.CreatedTs, @@ -185,7 +235,7 @@ func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) { return &memoRaw, nil } -func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) { +func patchMemoRaw(ctx context.Context, tx *sql.Tx, patch *api.MemoPatch) (*memoRaw, error) { set, args := []string{}, []interface{}{} if v := patch.Content; v != nil { @@ -200,21 +250,14 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) { args = append(args, patch.ID) - row, err := db.Query(` + query := ` UPDATE memo - SET `+strings.Join(set, ", ")+` + SET ` + strings.Join(set, ", ") + ` WHERE id = ? RETURNING id, creator_id, created_ts, updated_ts, row_status, content, visibility - `, args...) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - row.Next() - + ` var memoRaw memoRaw - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, args...).Scan( &memoRaw.ID, &memoRaw.CreatorID, &memoRaw.CreatedTs, @@ -229,7 +272,7 @@ func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) { return &memoRaw, nil } -func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { +func findMemoRawList(ctx context.Context, tx *sql.Tx, find *api.MemoFind) ([]*memoRaw, error) { where, args := []string{"1 = 1"}, []interface{}{} if v := find.ID; v != nil { @@ -264,7 +307,7 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { } } - rows, err := db.Query(` + query := ` SELECT id, creator_id, @@ -274,10 +317,10 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { content, visibility FROM memo - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`+pagination, - args..., - ) + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC + ` + pagination + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, FormatError(err) } @@ -308,8 +351,8 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { return memoRawList, nil } -func deleteMemo(db *sql.DB, delete *api.MemoDelete) error { - result, err := db.Exec(` +func deleteMemo(ctx context.Context, tx *sql.Tx, delete *api.MemoDelete) error { + _, err := tx.ExecContext(ctx, ` PRAGMA foreign_keys = ON; DELETE FROM memo WHERE id = ? `, delete.ID) @@ -317,26 +360,5 @@ func deleteMemo(db *sql.DB, delete *api.MemoDelete) error { return FormatError(err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo ID not found: %d", delete.ID)} - } - return nil } - -func (s *Store) composeMemo(raw *memoRaw) (*api.Memo, error) { - memo := raw.toMemo() - - memoOrganizer, err := s.FindMemoOrganizer(&api.MemoOrganizerFind{ - MemoID: memo.ID, - UserID: memo.CreatorID, - }) - if err != nil && common.ErrorCode(err) != common.NotFound { - return nil, err - } else if memoOrganizer != nil { - memo.Pinned = memoOrganizer.Pinned - } - - return memo, nil -} diff --git a/store/memo_organizer.go b/store/memo_organizer.go index 666dbbe4..9aadb8a2 100644 --- a/store/memo_organizer.go +++ b/store/memo_organizer.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" @@ -29,8 +30,14 @@ func (raw *memoOrganizerRaw) toMemoOrganizer() *api.MemoOrganizer { } } -func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) { - memoOrganizerRaw, err := findMemoOrganizer(s.db, find) +func (s *Store) FindMemoOrganizer(ctx context.Context, find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + memoOrganizerRaw, err := findMemoOrganizer(ctx, tx, find) if err != nil { return nil, err } @@ -40,17 +47,26 @@ func (s *Store) FindMemoOrganizer(find *api.MemoOrganizerFind) (*api.MemoOrganiz return memoOrganizer, nil } -func (s *Store) UpsertMemoOrganizer(upsert *api.MemoOrganizerUpsert) error { - err := upsertMemoOrganizer(s.db, upsert) +func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *api.MemoOrganizerUpsert) error { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + if err := upsertMemoOrganizer(ctx, tx, upsert); err != nil { return err } + if err := tx.Commit(); err != nil { + return FormatError(err) + } + return nil } -func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) { - row, err := db.Query(` +func findMemoOrganizer(ctx context.Context, tx *sql.Tx, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) { + query := ` SELECT id, memo_id, @@ -58,7 +74,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR pinned FROM memo_organizer WHERE memo_id = ? AND user_id = ? - `, find.MemoID, find.UserID) + ` + row, err := tx.QueryContext(ctx, query, find.MemoID, find.UserID) if err != nil { return nil, FormatError(err) } @@ -81,8 +98,8 @@ func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerR return &memoOrganizerRaw, nil } -func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error { - row, err := db.Query(` +func upsertMemoOrganizer(ctx context.Context, tx *sql.Tx, upsert *api.MemoOrganizerUpsert) error { + query := ` INSERT INTO memo_organizer ( memo_id, user_id, @@ -93,20 +110,9 @@ func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error { SET pinned = EXCLUDED.pinned RETURNING id, memo_id, user_id, pinned - `, - upsert.MemoID, - upsert.UserID, - upsert.Pinned, - ) - if err != nil { - return FormatError(err) - } - - defer row.Close() - - row.Next() + ` var memoOrganizer api.MemoOrganizer - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned).Scan( &memoOrganizer.ID, &memoOrganizer.MemoID, &memoOrganizer.UserID, diff --git a/store/resource.go b/store/resource.go index 68a45bf5..ce6b712a 100644 --- a/store/resource.go +++ b/store/resource.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" "strings" @@ -43,12 +44,22 @@ func (raw *resourceRaw) toResource() *api.Resource { } } -func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error) { - resourceRaw, err := createResource(s.db, create) +func (s *Store) CreateResource(ctx context.Context, create *api.ResourceCreate) (*api.Resource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + resourceRaw, err := createResource(ctx, tx, create) if err != nil { return nil, err } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + resource := resourceRaw.toResource() if err := s.cache.UpsertCache(api.ResourceCache, resource.ID, resource); err != nil { @@ -58,8 +69,14 @@ func (s *Store) CreateResource(create *api.ResourceCreate) (*api.Resource, error return resource, nil } -func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error) { - resourceRawList, err := findResourceList(s.db, find) +func (s *Store) FindResourceList(ctx context.Context, find *api.ResourceFind) ([]*api.Resource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + resourceRawList, err := findResourceList(ctx, tx, find) if err != nil { return nil, err } @@ -72,7 +89,7 @@ func (s *Store) FindResourceList(find *api.ResourceFind) ([]*api.Resource, error return resourceList, nil } -func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) { +func (s *Store) FindResource(ctx context.Context, find *api.ResourceFind) (*api.Resource, error) { if find.ID != nil { resource := &api.Resource{} has, err := s.cache.FindCache(api.ResourceCache, *find.ID, resource) @@ -84,7 +101,13 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) { } } - list, err := findResourceList(s.db, find) + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := findResourceList(ctx, tx, find) if err != nil { return nil, err } @@ -102,19 +125,29 @@ func (s *Store) FindResource(find *api.ResourceFind) (*api.Resource, error) { return resource, nil } -func (s *Store) DeleteResource(delete *api.ResourceDelete) error { - err := deleteResource(s.db, delete) +func (s *Store) DeleteResource(ctx context.Context, delete *api.ResourceDelete) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + err = deleteResource(ctx, tx, delete) if err != nil { return err } + if err := tx.Commit(); err != nil { + return FormatError(err) + } + s.cache.DeleteCache(api.ResourceCache, delete.ID) return nil } -func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error) { - row, err := db.Query(` +func createResource(ctx context.Context, tx *sql.Tx, create *api.ResourceCreate) (*resourceRaw, error) { + query := ` INSERT INTO resource ( filename, blob, @@ -124,21 +157,9 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error ) VALUES (?, ?, ?, ?, ?) RETURNING id, filename, blob, type, size, creator_id, created_ts, updated_ts - `, - create.Filename, - create.Blob, - create.Type, - create.Size, - create.CreatorID, - ) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - row.Next() + ` var resourceRaw resourceRaw - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, create.Filename, create.Blob, create.Type, create.Size, create.CreatorID).Scan( &resourceRaw.ID, &resourceRaw.Filename, &resourceRaw.Blob, @@ -154,7 +175,7 @@ func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error return &resourceRaw, nil } -func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error) { +func findResourceList(ctx context.Context, tx *sql.Tx, find *api.ResourceFind) ([]*resourceRaw, error) { where, args := []string{"1 = 1"}, []interface{}{} if v := find.ID; v != nil { @@ -167,7 +188,7 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error where, args = append(where, "filename = ?"), append(args, *v) } - rows, err := db.Query(` + query := ` SELECT id, filename, @@ -178,10 +199,10 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error created_ts, updated_ts FROM resource - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return nil, FormatError(err) } @@ -213,8 +234,8 @@ func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error return resourceRawList, nil } -func deleteResource(db *sql.DB, delete *api.ResourceDelete) error { - result, err := db.Exec(` +func deleteResource(ctx context.Context, tx *sql.Tx, delete *api.ResourceDelete) error { + _, err := tx.ExecContext(ctx, ` PRAGMA foreign_keys = ON; DELETE FROM resource WHERE id = ? AND creator_id = ? `, delete.ID, delete.CreatorID) @@ -222,10 +243,5 @@ func deleteResource(db *sql.DB, delete *api.ResourceDelete) error { return FormatError(err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("resource ID not found: %d", delete.ID)} - } - return nil } diff --git a/store/shortcut.go b/store/shortcut.go index 2c83ec06..2aeb2d9d 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" "strings" @@ -39,12 +40,22 @@ func (raw *shortcutRaw) toShortcut() *api.Shortcut { } } -func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error) { - shortcutRaw, err := createShortcut(s.db, create) +func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + shortcutRaw, err := createShortcut(ctx, tx, create) if err != nil { return nil, err } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + shortcut := shortcutRaw.toShortcut() if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil { @@ -54,12 +65,22 @@ func (s *Store) CreateShortcut(create *api.ShortcutCreate) (*api.Shortcut, error return shortcut, nil } -func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) { - shortcutRaw, err := patchShortcut(s.db, patch) +func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + shortcutRaw, err := patchShortcut(ctx, tx, patch) if err != nil { return nil, err } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + shortcut := shortcutRaw.toShortcut() if err := s.cache.UpsertCache(api.ShortcutCache, shortcut.ID, shortcut); err != nil { @@ -69,8 +90,14 @@ func (s *Store) PatchShortcut(patch *api.ShortcutPatch) (*api.Shortcut, error) { return shortcut, nil } -func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error) { - shortcutRawList, err := findShortcutList(s.db, find) +func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + shortcutRawList, err := findShortcutList(ctx, tx, find) if err != nil { return nil, err } @@ -83,7 +110,7 @@ func (s *Store) FindShortcutList(find *api.ShortcutFind) ([]*api.Shortcut, error return list, nil } -func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) { +func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) { if find.ID != nil { shortcut := &api.Shortcut{} has, err := s.cache.FindCache(api.ShortcutCache, *find.ID, shortcut) @@ -95,7 +122,13 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) { } } - list, err := findShortcutList(s.db, find) + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := findShortcutList(ctx, tx, find) if err != nil { return nil, err } @@ -113,19 +146,29 @@ func (s *Store) FindShortcut(find *api.ShortcutFind) (*api.Shortcut, error) { return shortcut, nil } -func (s *Store) DeleteShortcut(delete *api.ShortcutDelete) error { - err := deleteShortcut(s.db, delete) +func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return FormatError(err) } + defer tx.Rollback() + + err = deleteShortcut(ctx, tx, delete) + if err != nil { + return FormatError(err) + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } s.cache.DeleteCache(api.ShortcutCache, delete.ID) return nil } -func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error) { - row, err := db.Query(` +func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) { + query := ` INSERT INTO shortcut ( title, payload, @@ -133,19 +176,9 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error ) VALUES (?, ?, ?) RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status - `, - create.Title, - create.Payload, - create.CreatorID, - ) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - row.Next() + ` var shortcutRaw shortcutRaw - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan( &shortcutRaw.ID, &shortcutRaw.Title, &shortcutRaw.Payload, @@ -160,7 +193,7 @@ func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error return &shortcutRaw, nil } -func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) { +func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) { set, args := []string{}, []interface{}{} if v := patch.Title; v != nil { @@ -175,23 +208,14 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) { args = append(args, patch.ID) - row, err := db.Query(` + query := ` UPDATE shortcut - SET `+strings.Join(set, ", ")+` + SET ` + strings.Join(set, ", ") + ` WHERE id = ? RETURNING id, title, payload, created_ts, updated_ts, row_status - `, args...) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - if !row.Next() { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} - } - + ` var shortcutRaw shortcutRaw - if err := row.Scan( + if err := tx.QueryRowContext(ctx, query, args...).Scan( &shortcutRaw.ID, &shortcutRaw.Title, &shortcutRaw.Payload, @@ -205,7 +229,7 @@ func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) { return &shortcutRaw, nil } -func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error) { +func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) { where, args := []string{"1 = 1"}, []interface{}{} if v := find.ID; v != nil { @@ -218,7 +242,7 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error where, args = append(where, "title = ?"), append(args, *v) } - rows, err := db.Query(` + rows, err := tx.QueryContext(ctx, ` SELECT id, title, @@ -262,8 +286,8 @@ func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error return shortcutRawList, nil } -func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error { - result, err := db.Exec(` +func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error { + _, err := tx.ExecContext(ctx, ` PRAGMA foreign_keys = ON; DELETE FROM shortcut WHERE id = ? `, delete.ID) @@ -271,10 +295,5 @@ func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error { return FormatError(err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut ID not found: %d", delete.ID)} - } - return nil }