chore: remove unused transaction in store (#1995)

* chore: remove unused transaction in store

* chore: update
This commit is contained in:
boojack 2023-07-20 23:15:56 +08:00 committed by GitHub
parent c8961ad489
commit 4c33d8d762
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 680 additions and 1266 deletions

View file

@ -18,13 +18,7 @@ type Activity struct {
}
func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO activity (
creator_id,
type,
@ -34,17 +28,13 @@ func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity
VALUES (?, ?, ?, ?)
RETURNING id, created_ts
`
if err := tx.QueryRowContext(ctx, query, create.CreatorID, create.Type, create.Level, create.Payload).Scan(
if err := s.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type, create.Level, create.Payload).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
activity := create
return activity, nil
}

View file

@ -190,21 +190,15 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st
}
}
tx, err := db.DBInstance.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// upsert the newest version to migration_history
// Upsert the newest version to migration_history.
version := minorVersion + ".0"
if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{
if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
Version: version,
}); err != nil {
return fmt.Errorf("failed to upsert migration history with version: %s, err: %w", version, err)
}
return tx.Commit()
return nil
}
func (db *DB) seed(ctx context.Context) error {

View file

@ -2,7 +2,6 @@ package db
import (
"context"
"database/sql"
"strings"
)
@ -20,40 +19,6 @@ type MigrationHistoryFind struct {
}
func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) {
tx, err := db.DBInstance.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := findMigrationHistoryList(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
tx, err := db.DBInstance.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
migrationHistory, err := upsertMigrationHistory(ctx, tx, upsert)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return migrationHistory, nil
}
func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHistoryFind) ([]*MigrationHistory, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Version; v != nil {
@ -69,13 +34,13 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := db.DBInstance.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
migrationHistoryList := make([]*MigrationHistory, 0)
list := make([]*MigrationHistory, 0)
for rows.Next() {
var migrationHistory MigrationHistory
if err := rows.Scan(
@ -85,18 +50,18 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi
return nil, err
}
migrationHistoryList = append(migrationHistoryList, &migrationHistory)
list = append(list, &migrationHistory)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrationHistoryList, nil
return list, nil
}
func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
query := `
func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
stmt := `
INSERT INTO migration_history (
version
)
@ -107,7 +72,7 @@ func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHi
RETURNING version, created_ts
`
var migrationHistory MigrationHistory
if err := tx.QueryRowContext(ctx, query, upsert.Version).Scan(
if err := db.DBInstance.QueryRowContext(ctx, stmt, upsert.Version).Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {

View file

@ -2,7 +2,6 @@ package store
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
@ -63,23 +62,18 @@ type DeleteIdentityProvider struct {
}
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
var configBytes []byte
if create.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(create.Config.OAuth2Config)
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
}
query := `
stmt := `
INSERT INTO idp (
name,
type,
@ -89,9 +83,9 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
VALUES (?, ?, ?, ?)
RETURNING id
`
if err := tx.QueryRowContext(
if err := s.db.QueryRowContext(
ctx,
query,
stmt,
create.Name,
create.Type,
create.IdentifierFilter,
@ -102,166 +96,18 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
identityProvider := create
s.idpCache.Store(identityProvider.ID, identityProvider)
return identityProvider, nil
}
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listIdentityProviders(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
for _, item := range list {
s.idpCache.Store(item.ID, item)
}
return list, nil
}
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
if find.ID != nil {
if cache, ok := s.idpCache.Load(*find.ID); ok {
return cache.(*IdentityProvider), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listIdentityProviders(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
identityProvider := list[0]
s.idpCache.Store(identityProvider.ID, identityProvider)
return identityProvider, nil
}
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.Name; v != nil {
set, args = append(set, "name = ?"), append(args, *v)
}
if v := update.IdentifierFilter; v != nil {
set, args = append(set, "identifier_filter = ?"), append(args, *v)
}
if v := update.Config; v != nil {
var configBytes []byte
if update.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
}
set, args = append(set, "config = ?"), append(args, string(configBytes))
}
args = append(args, update.ID)
query := `
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, name, type, identifier_filter, config
`
var identityProvider IdentityProvider
var identityProviderConfig string
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err
}
if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
if err := tx.Commit(); err != nil {
return nil, err
}
s.idpCache.Store(identityProvider.ID, identityProvider)
return &identityProvider, nil
}
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{"id = ?"}, []any{delete.ID}
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.idpCache.Delete(delete.ID)
return nil
}
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
}
rows, err := tx.QueryContext(ctx, `
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
name,
@ -309,5 +155,99 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr
return nil, err
}
for _, item := range identityProviders {
s.idpCache.Store(item.ID, item)
}
return identityProviders, nil
}
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
if find.ID != nil {
if cache, ok := s.idpCache.Load(*find.ID); ok {
return cache.(*IdentityProvider), nil
}
}
list, err := s.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
identityProvider := list[0]
s.idpCache.Store(identityProvider.ID, identityProvider)
return identityProvider, nil
}
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
set, args := []string{}, []any{}
if v := update.Name; v != nil {
set, args = append(set, "name = ?"), append(args, *v)
}
if v := update.IdentifierFilter; v != nil {
set, args = append(set, "identifier_filter = ?"), append(args, *v)
}
if v := update.Config; v != nil {
var configBytes []byte
if update.Type == IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(update.Type))
}
set, args = append(set, "config = ?"), append(args, string(configBytes))
}
args = append(args, update.ID)
stmt := `
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, name, type, identifier_filter, config
`
var identityProvider IdentityProvider
var identityProviderConfig string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err
}
if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
s.idpCache.Store(identityProvider.ID, identityProvider)
return &identityProvider, nil
}
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
where, args := []string{"id = ?"}, []any{delete.ID}
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
s.idpCache.Delete(delete.ID)
return nil
}

View file

@ -84,17 +84,11 @@ type DeleteMemo struct {
}
func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
if create.CreatedTs == 0 {
create.CreatedTs = time.Now().Unix()
}
query := `
stmt := `
INSERT INTO memo (
creator_id,
created_ts,
@ -104,9 +98,9 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) {
VALUES (?, ?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(
if err := s.db.QueryRowContext(
ctx,
query,
stmt,
create.CreatorID,
create.CreatedTs,
create.Content,
@ -119,155 +113,12 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) {
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
memo := create
return memo, nil
}
func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemos(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemos(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
memo := list[0]
return memo, nil
}
func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.CreatedTs; v != nil {
set, args = append(set, "created_ts = ?"), append(args, *v)
}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Content; v != nil {
set, args = append(set, "content = ?"), append(args, *v)
}
if v := update.Visibility; v != nil {
set, args = append(set, "visibility = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE memo
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
`
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return err
}
err = tx.Commit()
return err
}
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{"id = ?"}, []any{delete.ID}
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if err := s.vacuumImpl(ctx, tx); err != nil {
return err
}
err = tx.Commit()
return err
}
func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
args := make([]any, 0, len(memoIDs))
list := make([]string, 0, len(memoIDs))
for _, memoID := range memoIDs {
args = append(args, memoID)
list = append(list, "?")
}
where := fmt.Sprintf("id in (%s)", strings.Join(list, ","))
query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
visibilityList := make([]Visibility, 0)
for rows.Next() {
var visibility Visibility
if err := rows.Scan(&visibility); err != nil {
return nil, err
}
visibilityList = append(visibilityList, visibility)
}
if err := rows.Err(); err != nil {
return nil, err
}
return visibilityList, nil
}
func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
@ -341,7 +192,7 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error)
}
}
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -407,6 +258,98 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error)
return list, nil
}
func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) {
list, err := s.ListMemos(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
memo := list[0]
return memo, nil
}
func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error {
set, args := []string{}, []any{}
if v := update.CreatedTs; v != nil {
set, args = append(set, "created_ts = ?"), append(args, *v)
}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Content; v != nil {
set, args = append(set, "content = ?"), append(args, *v)
}
if v := update.Visibility; v != nil {
set, args = append(set, "visibility = ?"), append(args, *v)
}
args = append(args, update.ID)
stmt := `
UPDATE memo
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
`
if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
return nil
}
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
where, args := []string{"id = ?"}, []any{delete.ID}
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := s.Vacuum(ctx); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) {
args := make([]any, 0, len(memoIDs))
list := make([]string, 0, len(memoIDs))
for _, memoID := range memoIDs {
args = append(args, memoID)
list = append(list, "?")
}
where := fmt.Sprintf("id in (%s)", strings.Join(list, ","))
query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
visibilityList := make([]Visibility, 0)
for rows.Next() {
var visibility Visibility
if err := rows.Scan(&visibility); err != nil {
return nil, err
}
visibilityList = append(visibilityList, visibility)
}
if err := rows.Err(); err != nil {
return nil, err
}
return visibilityList, nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

View file

@ -24,13 +24,7 @@ type DeleteMemoOrganizer struct {
}
func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO memo_organizer (
memo_id,
user_id,
@ -41,11 +35,7 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer)
SET
pinned = EXCLUDED.pinned
`
if _, err := tx.ExecContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
if _, err := s.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil {
return nil, err
}
@ -54,12 +44,6 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer)
}
func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if find.MemoID != 0 {
where = append(where, "memo_id = ?")
@ -78,7 +62,7 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (
FROM memo_organizer
WHERE %s
`, strings.Join(where, " AND "))
row := tx.QueryRowContext(ctx, query, args...)
row := s.db.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
return nil, err
}
@ -95,40 +79,21 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return memoOrganizer, nil
}
func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *DeleteMemoOrganizer) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if v := delete.MemoID; v != nil {
where, args = append(where, "memo_id = ?"), append(args, *v)
}
if v := delete.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *v)
}
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
}
return nil
}

View file

@ -32,13 +32,7 @@ type DeleteMemoRelation struct {
}
func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO memo_relation (
memo_id,
related_memo_id,
@ -50,9 +44,9 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*
RETURNING memo_id, related_memo_id, type
`
memoRelation := &MemoRelation{}
if err := tx.QueryRowContext(
if err := s.db.QueryRowContext(
ctx,
query,
stmt,
create.MemoID,
create.RelatedMemoID,
create.Type,
@ -64,88 +58,10 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return memoRelation, nil
}
func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemoRelations(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemoRelations(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list[0], nil
}
func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "type = ?"), append(args, delete.Type)
}
query := `
DELETE FROM memo_relation
WHERE ` + strings.Join(where, " AND ")
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// Prevent lint warning.
return err
}
return nil
}
func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) ([]*MemoRelation, error) {
where, args := []string{"TRUE"}, []any{}
if find.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
@ -157,7 +73,7 @@ func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation)
where, args = append(where, "type = ?"), append(args, find.Type)
}
rows, err := tx.QueryContext(ctx, `
rows, err := s.db.QueryContext(ctx, `
SELECT
memo_id,
related_memo_id,
@ -169,22 +85,61 @@ func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation)
}
defer rows.Close()
memoRelationMessages := []*MemoRelation{}
list := []*MemoRelation{}
for rows.Next() {
memoRelationMessage := &MemoRelation{}
memoRelation := &MemoRelation{}
if err := rows.Scan(
&memoRelationMessage.MemoID,
&memoRelationMessage.RelatedMemoID,
&memoRelationMessage.Type,
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
memoRelationMessages = append(memoRelationMessages, memoRelationMessage)
list = append(list, memoRelation)
}
if err := rows.Err(); err != nil {
return nil, err
}
return memoRelationMessages, nil
return list, nil
}
func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) {
list, err := s.ListMemoRelations(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error {
where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "type = ?"), append(args, delete.Type)
}
stmt := `
DELETE FROM memo_relation
WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {

View file

@ -31,12 +31,6 @@ type DeleteMemoResource struct {
}
func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set := []string{"memo_id", "resource_id"}
args := []any{upsert.MemoID, upsert.ResourceID}
placeholder := []string{"?", "?"}
@ -56,7 +50,7 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour
RETURNING memo_id, resource_id, created_ts, updated_ts
`
memoResource := &MemoResource{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
if err := s.db.QueryRowContext(ctx, query, args...).Scan(
&memoResource.MemoID,
&memoResource.ResourceID,
&memoResource.CreatedTs,
@ -65,86 +59,10 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return memoResource, nil
}
func (s *Store) ListMemoResources(ctx context.Context, find *FindMemoResource) ([]*MemoResource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemoResources(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listMemoResources(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
memoResource := list[0]
return memoResource, nil
}
func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if v := delete.MemoID; v != nil {
where, args = append(where, "memo_id = ?"), append(args, *v)
}
if v := delete.ResourceID; v != nil {
where, args = append(where, "resource_id = ?"), append(args, *v)
}
stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ")
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) ([]*MemoResource, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.MemoID; v != nil {
@ -164,7 +82,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY updated_ts DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -192,6 +110,38 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource)
return list, nil
}
func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) {
list, err := s.ListMemoResources(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
memoResource := list[0]
return memoResource, nil
}
func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error {
where, args := []string{}, []any{}
if v := delete.MemoID; v != nil {
where, args = append(where, "memo_id = ?"), append(args, *v)
}
if v := delete.ResourceID; v != nil {
where, args = append(where, "resource_id = ?"), append(args, *v)
}
stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}
func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

View file

@ -46,13 +46,7 @@ type DeleteResource struct {
}
func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
if err := tx.QueryRowContext(ctx, `
stmt := `
INSERT INTO resource (
filename,
blob,
@ -64,131 +58,26 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource
)
VALUES (?, ?, ?, ?, ?, ?, ?)
RETURNING id, created_ts, updated_ts
`,
create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath,
`
if err := s.db.QueryRowContext(
ctx,
stmt,
create.Filename,
create.Blob,
create.ExternalLink,
create.Type,
create.Size,
create.CreatorID,
create.InternalPath,
).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
resource := create
return resource, nil
}
func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
resources, err := listResources(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return resources, nil
}
func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
resources, err := listResources(ctx, tx, find)
if err != nil {
return nil, err
}
if len(resources) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
return resources[0], nil
}
func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.Filename; v != nil {
set, args = append(set, "filename = ?"), append(args, *v)
}
args = append(args, update.ID)
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"}
query := `
UPDATE resource
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING ` + strings.Join(fields, ", ")
resource := Resource{}
dests := []any{
&resource.ID,
&resource.Filename,
&resource.ExternalLink,
&resource.Type,
&resource.Size,
&resource.CreatorID,
&resource.CreatedTs,
&resource.UpdatedTs,
&resource.InternalPath,
}
if err := tx.QueryRowContext(ctx, query, args...).Scan(dests...); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &resource, nil
}
func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM resource
WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Resource, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
@ -226,7 +115,7 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso
}
}
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -263,6 +152,74 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso
return list, nil
}
func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) {
resources, err := s.ListResources(ctx, find)
if err != nil {
return nil, err
}
if len(resources) == 0 {
return nil, nil
}
return resources[0], nil
}
func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) {
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.Filename; v != nil {
set, args = append(set, "filename = ?"), append(args, *v)
}
args = append(args, update.ID)
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"}
stmt := `
UPDATE resource
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING ` + strings.Join(fields, ", ")
resource := Resource{}
dests := []any{
&resource.ID,
&resource.Filename,
&resource.ExternalLink,
&resource.Type,
&resource.Size,
&resource.CreatorID,
&resource.CreatedTs,
&resource.UpdatedTs,
&resource.InternalPath,
}
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil {
return nil, err
}
return &resource, nil
}
func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error {
stmt := `
DELETE FROM resource
WHERE id = ?
`
result, err := s.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := s.Vacuum(ctx); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

View file

@ -41,13 +41,7 @@ type DeleteShortcut struct {
}
func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO shortcut (
title,
payload,
@ -56,7 +50,7 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut
VALUES (?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
if err := s.db.QueryRowContext(ctx, stmt, create.Title, create.Payload, create.CreatorID).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
@ -65,134 +59,11 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
shortcut := create
return shortcut, nil
}
func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
shortcut := list[0]
return shortcut, nil
}
func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.Title; v != nil {
set, args = append(set, "title = ?"), append(args, *v)
}
if v := update.Payload; v != nil {
set, args = append(set, "payload = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE shortcut
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
`
shortcut := &Shortcut{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcut.ID,
&shortcut.Title,
&shortcut.Payload,
&shortcut.CreatorID,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&shortcut.RowStatus,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return shortcut, nil
}
func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
if _, err := tx.ExecContext(ctx, stmt, args...); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.shortcutCache.Delete(*delete.ID)
return nil
}
func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
@ -205,7 +76,7 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor
where, args = append(where, "title = ?"), append(args, *v)
}
rows, err := tx.QueryContext(ctx, `
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
title,
@ -248,6 +119,78 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor
return list, nil
}
func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) {
list, err := s.ListShortcuts(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
shortcut := list[0]
return shortcut, nil
}
func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) {
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := update.Title; v != nil {
set, args = append(set, "title = ?"), append(args, *v)
}
if v := update.Payload; v != nil {
set, args = append(set, "payload = ?"), append(args, *v)
}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
args = append(args, update.ID)
stmt := `
UPDATE shortcut
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
`
shortcut := &Shortcut{}
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&shortcut.ID,
&shortcut.Title,
&shortcut.Payload,
&shortcut.CreatorID,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&shortcut.RowStatus,
); err != nil {
return nil, err
}
return shortcut, nil
}
func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
s.shortcutCache.Delete(*delete.ID)
return nil
}
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

View file

@ -2,7 +2,6 @@ package store
import (
"context"
"database/sql"
"strings"
)
@ -28,13 +27,7 @@ type DeleteStorage struct {
}
func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO storage (
name,
type,
@ -43,136 +36,23 @@ func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, e
VALUES (?, ?, ?)
RETURNING id
`
if err := tx.QueryRowContext(ctx, query, create.Name, create.Type, create.Config).Scan(
if err := s.db.QueryRowContext(ctx, stmt, create.Name, create.Type, create.Config).Scan(
&create.ID,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
storage := create
return storage, nil
}
func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listStorages(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listStorages(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list[0], nil
}
func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if update.Name != nil {
set = append(set, "name = ?")
args = append(args, *update.Name)
}
if update.Config != nil {
set = append(set, "config = ?")
args = append(args, *update.Config)
}
args = append(args, update.ID)
query := `
UPDATE storage
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING
id,
name,
type,
config
`
storage := &Storage{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&storage.ID,
&storage.Name,
&storage.Type,
&storage.Config,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return storage, nil
}
func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
query := `
DELETE FROM storage
WHERE id = ?
`
if _, err := tx.ExecContext(ctx, query, delete.ID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
}
return nil
}
func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storage, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = ?"), append(args, *find.ID)
}
rows, err := tx.QueryContext(ctx, `
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
name,
@ -208,3 +88,65 @@ func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storag
return list, nil
}
func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) {
list, err := s.ListStorages(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) {
set, args := []string{}, []any{}
if update.Name != nil {
set = append(set, "name = ?")
args = append(args, *update.Name)
}
if update.Config != nil {
set = append(set, "config = ?")
args = append(args, *update.Config)
}
args = append(args, update.ID)
stmt := `
UPDATE storage
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING
id,
name,
type,
config
`
storage := &Storage{}
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&storage.ID,
&storage.Name,
&storage.Type,
&storage.Config,
); err != nil {
return nil, err
}
return storage, nil
}
func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error {
stmt := `
DELETE FROM storage
WHERE id = ?
`
result, err := s.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View file

@ -2,7 +2,6 @@ package store
import (
"context"
"database/sql"
"strings"
)
@ -17,13 +16,7 @@ type FindSystemSetting struct {
}
func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO system_setting (
name, value, description
)
@ -33,11 +26,7 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting)
value = EXCLUDED.value,
description = EXCLUDED.description
`
if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
@ -46,68 +35,6 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting)
}
func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listSystemSettings(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
for _, systemSettingMessage := range list {
s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
}
return list, nil
}
func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (*SystemSetting, error) {
if find.Name != "" {
if cache, ok := s.systemSettingCache.Load(find.Name); ok {
return cache.(*SystemSetting), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listSystemSettings(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
systemSettingMessage := list[0]
s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
return systemSettingMessage, nil
}
func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string {
if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{
Name: settingName,
}); err == nil && setting != nil {
return setting.Value
}
return defaultValue
}
func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Name != "" {
where, args = append(where, "name = ?"), append(args, find.Name)
@ -121,7 +48,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
FROM system_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -144,5 +71,38 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
return nil, err
}
for _, systemSettingMessage := range list {
s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
}
return list, nil
}
func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (*SystemSetting, error) {
if find.Name != "" {
if cache, ok := s.systemSettingCache.Load(find.Name); ok {
return cache.(*SystemSetting), nil
}
}
list, err := s.ListSystemSettings(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
systemSettingMessage := list[0]
s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
return systemSettingMessage, nil
}
func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string {
if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{
Name: settingName,
}); err == nil && setting != nil {
return setting.Value
}
return defaultValue
}

View file

@ -3,7 +3,6 @@ package store
import (
"context"
"database/sql"
"fmt"
"strings"
)
@ -22,13 +21,7 @@ type DeleteTag struct {
}
func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO tag (
name, creator_id
)
@ -37,11 +30,7 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) {
SET
name = EXCLUDED.name
`
if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.CreatorID); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID); err != nil {
return nil, err
}
@ -50,12 +39,6 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) {
}
func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
where, args := []string{"creator_id = ?"}, []any{find.CreatorID}
query := `
SELECT
@ -65,7 +48,7 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) {
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY name ASC
`
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -88,37 +71,19 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID}
query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, query, args...)
stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
result, err := s.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("tag not found")
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}

View file

@ -2,8 +2,6 @@ package store
import (
"context"
"database/sql"
"errors"
"strings"
)
@ -79,13 +77,7 @@ type DeleteUser struct {
}
func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO user (
username,
role,
@ -97,7 +89,9 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id, avatar_url, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query,
if err := s.db.QueryRowContext(
ctx,
stmt,
create.Username,
create.Role,
create.Email,
@ -113,9 +107,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
user := create
s.userCache.Store(user.ID, user)
@ -123,12 +114,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
}
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
@ -163,7 +148,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
`
user := &User{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
if err := s.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
@ -179,100 +164,11 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
s.userCache.Store(user.ID, user)
return user, nil
}
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
for _, user := range list {
s.userCache.Store(user.ID, user)
}
return list, nil
}
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
if find.ID != nil {
if cache, ok := s.userCache.Load(*find.ID); ok {
return cache.(*User), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
user := list[0]
s.userCache.Store(user.ID, user)
return user, nil
}
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
result, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("user not found")
}
if err := s.vacuumImpl(ctx, tx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.userCache.Delete(delete.ID)
return nil
}
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
@ -311,7 +207,7 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -342,5 +238,46 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error)
return nil, err
}
for _, user := range list {
s.userCache.Store(user.ID, user)
}
return list, nil
}
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
if find.ID != nil {
if cache, ok := s.userCache.Load(*find.ID); ok {
return cache.(*User), nil
}
}
list, err := s.ListUsers(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
user := list[0]
s.userCache.Store(user.ID, user)
return user, nil
}
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
result, err := s.db.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
if err := s.Vacuum(ctx); err != nil {
// Prevent linter warning.
return err
}
s.userCache.Delete(delete.ID)
return nil
}

View file

@ -18,13 +18,7 @@ type FindUserSetting struct {
}
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
@ -32,11 +26,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
if _, err := tx.ExecContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
if _, err := s.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value); err != nil {
return nil, err
}
@ -46,59 +36,6 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us
}
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
userSettingList, err := listUserSettings(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
}
return userSettingList, nil
}
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) {
if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok {
return cache.(*UserSetting), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listUserSettings(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
if err := tx.Commit(); err != nil {
return nil, err
}
userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil
}
func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([]*UserSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != "" {
@ -115,7 +52,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
@ -138,9 +75,33 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
return nil, err
}
for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
}
return userSettingList, nil
}
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) {
if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok {
return cache.(*UserSetting), nil
}
}
list, err := s.ListUserSettings(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

47
test/store/store_test.go Normal file
View file

@ -0,0 +1,47 @@
package teststore
import (
"context"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestConcurrentReadWrite(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
const numWorkers = 10
const numIterations = 100
wg := sync.WaitGroup{}
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go func() {
for j := 0; j < numIterations; j++ {
_, err := ts.CreateMemo(ctx, &store.Memo{
CreatorID: user.ID,
Content: fmt.Sprintf("test_content_%d", i),
Visibility: store.Public,
})
require.NoError(t, err)
}
}()
go func() {
_, err := ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
})
require.NoError(t, err)
wg.Done()
}()
}
wg.Wait()
}