diff --git a/store/activity.go b/store/activity.go index 1b7a12a5..639bb43e 100644 --- a/store/activity.go +++ b/store/activity.go @@ -18,13 +18,7 @@ type Activity struct { } func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO activity ( creator_id, type, @@ -34,17 +28,13 @@ func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity VALUES (?, ?, ?, ?) RETURNING id, created_ts ` - if err := tx.QueryRowContext(ctx, query, create.CreatorID, create.Type, create.Level, create.Payload).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type, create.Level, create.Payload).Scan( &create.ID, &create.CreatedTs, ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - activity := create return activity, nil } diff --git a/store/db/db.go b/store/db/db.go index df0cdc53..b57f59cc 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -190,21 +190,15 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st } } - tx, err := db.DBInstance.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // upsert the newest version to migration_history + // Upsert the newest version to migration_history. version := minorVersion + ".0" - if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{ + if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ Version: version, }); err != nil { return fmt.Errorf("failed to upsert migration history with version: %s, err: %w", version, err) } - return tx.Commit() + return nil } func (db *DB) seed(ctx context.Context) error { diff --git a/store/db/migration_history.go b/store/db/migration_history.go index cbda3445..e4b897e6 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "strings" ) @@ -20,40 +19,6 @@ type MigrationHistoryFind struct { } func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := findMigrationHistoryList(ctx, tx, find) - if err != nil { - return nil, err - } - - return list, nil -} - -func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - migrationHistory, err := upsertMigrationHistory(ctx, tx, upsert) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return migrationHistory, nil -} - -func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHistoryFind) ([]*MigrationHistory, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Version; v != nil { @@ -69,13 +34,13 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi WHERE ` + strings.Join(where, " AND ") + ` ORDER BY created_ts DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := db.DBInstance.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - migrationHistoryList := make([]*MigrationHistory, 0) + list := make([]*MigrationHistory, 0) for rows.Next() { var migrationHistory MigrationHistory if err := rows.Scan( @@ -85,18 +50,18 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi return nil, err } - migrationHistoryList = append(migrationHistoryList, &migrationHistory) + list = append(list, &migrationHistory) } if err := rows.Err(); err != nil { return nil, err } - return migrationHistoryList, nil + return list, nil } -func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - query := ` +func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { + stmt := ` INSERT INTO migration_history ( version ) @@ -107,7 +72,7 @@ func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHi RETURNING version, created_ts ` var migrationHistory MigrationHistory - if err := tx.QueryRowContext(ctx, query, upsert.Version).Scan( + if err := db.DBInstance.QueryRowContext(ctx, stmt, upsert.Version).Scan( &migrationHistory.Version, &migrationHistory.CreatedTs, ); err != nil { diff --git a/store/idp.go b/store/idp.go index 4f22a871..e94eab73 100644 --- a/store/idp.go +++ b/store/idp.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "encoding/json" "fmt" "strings" @@ -63,23 +62,18 @@ type DeleteIdentityProvider struct { } func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - var configBytes []byte if create.Type == IdentityProviderOAuth2Type { - configBytes, err = json.Marshal(create.Config.OAuth2Config) + bytes, err := json.Marshal(create.Config.OAuth2Config) if err != nil { return nil, err } + configBytes = bytes } else { return nil, fmt.Errorf("unsupported idp type %s", string(create.Type)) } - query := ` + stmt := ` INSERT INTO idp ( name, type, @@ -89,9 +83,9 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv VALUES (?, ?, ?, ?) RETURNING id ` - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.Name, create.Type, create.IdentifierFilter, @@ -102,166 +96,18 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - identityProvider := create s.idpCache.Store(identityProvider.ID, identityProvider) return identityProvider, nil } func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listIdentityProviders(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, item := range list { - s.idpCache.Store(item.ID, item) - } - return list, nil -} - -func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) { - if find.ID != nil { - if cache, ok := s.idpCache.Load(*find.ID); ok { - return cache.(*IdentityProvider), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listIdentityProviders(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - identityProvider := list[0] - s.idpCache.Store(identityProvider.ID, identityProvider) - return identityProvider, nil -} - -func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - if v := update.Name; v != nil { - set, args = append(set, "name = ?"), append(args, *v) - } - if v := update.IdentifierFilter; v != nil { - set, args = append(set, "identifier_filter = ?"), append(args, *v) - } - if v := update.Config; v != nil { - var configBytes []byte - if update.Type == IdentityProviderOAuth2Type { - configBytes, err = json.Marshal(update.Config.OAuth2Config) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("unsupported idp type %s", string(update.Type)) - } - set, args = append(set, "config = ?"), append(args, string(configBytes)) - } - args = append(args, update.ID) - - query := ` - UPDATE idp - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING id, name, type, identifier_filter, config - ` - var identityProvider IdentityProvider - var identityProviderConfig string - if err := tx.QueryRowContext(ctx, query, args...).Scan( - &identityProvider.ID, - &identityProvider.Name, - &identityProvider.Type, - &identityProvider.IdentifierFilter, - &identityProviderConfig, - ); err != nil { - return nil, err - } - - if identityProvider.Type == IdentityProviderOAuth2Type { - oauth2Config := &IdentityProviderOAuth2Config{} - if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { - return nil, err - } - identityProvider.Config = &IdentityProviderConfig{ - OAuth2Config: oauth2Config, - } - } else { - return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - s.idpCache.Store(identityProvider.ID, identityProvider) - return &identityProvider, nil -} - -func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{"id = ?"}, []any{delete.ID} - stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if _, err = result.RowsAffected(); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return err - } - - s.idpCache.Delete(delete.ID) - return nil -} - -func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) } - rows, err := tx.QueryContext(ctx, ` + rows, err := s.db.QueryContext(ctx, ` SELECT id, name, @@ -309,5 +155,99 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr return nil, err } + for _, item := range identityProviders { + s.idpCache.Store(item.ID, item) + } return identityProviders, nil } + +func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) { + if find.ID != nil { + if cache, ok := s.idpCache.Load(*find.ID); ok { + return cache.(*IdentityProvider), nil + } + } + + list, err := s.ListIdentityProviders(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + identityProvider := list[0] + s.idpCache.Store(identityProvider.ID, identityProvider) + return identityProvider, nil +} + +func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { + set, args := []string{}, []any{} + if v := update.Name; v != nil { + set, args = append(set, "name = ?"), append(args, *v) + } + if v := update.IdentifierFilter; v != nil { + set, args = append(set, "identifier_filter = ?"), append(args, *v) + } + if v := update.Config; v != nil { + var configBytes []byte + if update.Type == IdentityProviderOAuth2Type { + bytes, err := json.Marshal(update.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(update.Type)) + } + set, args = append(set, "config = ?"), append(args, string(configBytes)) + } + args = append(args, update.ID) + + stmt := ` + UPDATE idp + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, name, type, identifier_filter, config + ` + var identityProvider IdentityProvider + var identityProviderConfig string + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == IdentityProviderOAuth2Type { + oauth2Config := &IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + + s.idpCache.Store(identityProvider.ID, identityProvider) + return &identityProvider, nil +} + +func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + s.idpCache.Delete(delete.ID) + return nil +} diff --git a/store/memo.go b/store/memo.go index fba079c3..83c7b0f9 100644 --- a/store/memo.go +++ b/store/memo.go @@ -84,17 +84,11 @@ type DeleteMemo struct { } func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - if create.CreatedTs == 0 { create.CreatedTs = time.Now().Unix() } - query := ` + stmt := ` INSERT INTO memo ( creator_id, created_ts, @@ -104,9 +98,9 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { VALUES (?, ?, ?, ?) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.CreatorID, create.CreatedTs, create.Content, @@ -119,155 +113,12 @@ func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } memo := create return memo, nil } func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemos(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemos(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - memo := list[0] - return memo, nil -} - -func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - if v := update.CreatedTs; v != nil { - set, args = append(set, "created_ts = ?"), append(args, *v) - } - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = ?"), append(args, *v) - } - if v := update.Content; v != nil { - set, args = append(set, "content = ?"), append(args, *v) - } - if v := update.Visibility; v != nil { - set, args = append(set, "visibility = ?"), append(args, *v) - } - args = append(args, update.ID) - - query := ` - UPDATE memo - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - ` - if _, err := tx.ExecContext(ctx, query, args...); err != nil { - return err - } - err = tx.Commit() - return err -} - -func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{"id = ?"}, []any{delete.ID} - stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if err := s.vacuumImpl(ctx, tx); err != nil { - return err - } - err = tx.Commit() - return err -} - -func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - args := make([]any, 0, len(memoIDs)) - list := make([]string, 0, len(memoIDs)) - for _, memoID := range memoIDs { - args = append(args, memoID) - list = append(list, "?") - } - - where := fmt.Sprintf("id in (%s)", strings.Join(list, ",")) - - query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where - - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - visibilityList := make([]Visibility, 0) - for rows.Next() { - var visibility Visibility - if err := rows.Scan(&visibility); err != nil { - return nil, err - } - visibilityList = append(visibilityList, visibility) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return visibilityList, nil -} - -func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -341,7 +192,7 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) } } - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -407,6 +258,98 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) return list, nil } +func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) { + list, err := s.ListMemos(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memo := list[0] + return memo, nil +} + +func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { + set, args := []string{}, []any{} + if v := update.CreatedTs; v != nil { + set, args = append(set, "created_ts = ?"), append(args, *v) + } + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Content; v != nil { + set, args = append(set, "content = ?"), append(args, *v) + } + if v := update.Visibility; v != nil { + set, args = append(set, "visibility = ?"), append(args, *v) + } + args = append(args, update.ID) + + stmt := ` + UPDATE memo + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil { + return err + } + return nil +} + +func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + return nil +} + +func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) { + args := make([]any, 0, len(memoIDs)) + list := make([]string, 0, len(memoIDs)) + for _, memoID := range memoIDs { + args = append(args, memoID) + list = append(list, "?") + } + + where := fmt.Sprintf("id in (%s)", strings.Join(list, ",")) + query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + visibilityList := make([]Visibility, 0) + for rows.Next() { + var visibility Visibility + if err := rows.Scan(&visibility); err != nil { + return nil, err + } + visibilityList = append(visibilityList, visibility) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return visibilityList, nil +} + func vacuumMemo(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/memo_organizer.go b/store/memo_organizer.go index 1047f30c..910440a8 100644 --- a/store/memo_organizer.go +++ b/store/memo_organizer.go @@ -24,13 +24,7 @@ type DeleteMemoOrganizer struct { } func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO memo_organizer ( memo_id, user_id, @@ -41,11 +35,7 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) SET pinned = EXCLUDED.pinned ` - if _, err := tx.ExecContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil { return nil, err } @@ -54,12 +44,6 @@ func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) } func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - where, args := []string{}, []any{} if find.MemoID != 0 { where = append(where, "memo_id = ?") @@ -78,7 +62,7 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) ( FROM memo_organizer WHERE %s `, strings.Join(where, " AND ")) - row := tx.QueryRowContext(ctx, query, args...) + row := s.db.QueryRowContext(ctx, query, args...) if err := row.Err(); err != nil { return nil, err } @@ -95,40 +79,21 @@ func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) ( return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoOrganizer, nil } func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *DeleteMemoOrganizer) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{}, []any{} - if v := delete.MemoID; v != nil { where, args = append(where, "memo_id = ?"), append(args, *v) } if v := delete.UserID; v != nil { where, args = append(where, "user_id = ?"), append(args, *v) } - stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { + if _, err := s.db.ExecContext(ctx, stmt, args...); err != nil { return err } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - return nil } diff --git a/store/memo_relation.go b/store/memo_relation.go index 3230b54a..d32a7f35 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -32,13 +32,7 @@ type DeleteMemoRelation struct { } func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO memo_relation ( memo_id, related_memo_id, @@ -50,9 +44,9 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* RETURNING memo_id, related_memo_id, type ` memoRelation := &MemoRelation{} - if err := tx.QueryRowContext( + if err := s.db.QueryRowContext( ctx, - query, + stmt, create.MemoID, create.RelatedMemoID, create.Type, @@ -64,88 +58,10 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoRelation, nil } func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoRelations(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoRelations(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list[0], nil -} - -func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{"TRUE"}, []any{} - if delete.MemoID != nil { - where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) - } - if delete.RelatedMemoID != nil { - where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) - } - if delete.Type != nil { - where, args = append(where, "type = ?"), append(args, delete.Type) - } - - query := ` - DELETE FROM memo_relation - WHERE ` + strings.Join(where, " AND ") - if _, err := tx.ExecContext(ctx, query, args...); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent lint warning. - return err - } - return nil -} - -func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) ([]*MemoRelation, error) { where, args := []string{"TRUE"}, []any{} if find.MemoID != nil { where, args = append(where, "memo_id = ?"), append(args, find.MemoID) @@ -157,7 +73,7 @@ func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) where, args = append(where, "type = ?"), append(args, find.Type) } - rows, err := tx.QueryContext(ctx, ` + rows, err := s.db.QueryContext(ctx, ` SELECT memo_id, related_memo_id, @@ -169,22 +85,61 @@ func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) } defer rows.Close() - memoRelationMessages := []*MemoRelation{} + list := []*MemoRelation{} for rows.Next() { - memoRelationMessage := &MemoRelation{} + memoRelation := &MemoRelation{} if err := rows.Scan( - &memoRelationMessage.MemoID, - &memoRelationMessage.RelatedMemoID, - &memoRelationMessage.Type, + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, ); err != nil { return nil, err } - memoRelationMessages = append(memoRelationMessages, memoRelationMessage) + list = append(list, memoRelation) } + if err := rows.Err(); err != nil { return nil, err } - return memoRelationMessages, nil + + return list, nil +} + +func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) { + list, err := s.ListMemoRelations(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + return list[0], nil +} + +func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error { + where, args := []string{"TRUE"}, []any{} + if delete.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) + } + if delete.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) + } + if delete.Type != nil { + where, args = append(where, "type = ?"), append(args, delete.Type) + } + stmt := ` + DELETE FROM memo_relation + WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil } func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { diff --git a/store/memo_resource.go b/store/memo_resource.go index 41024b37..80d993d0 100644 --- a/store/memo_resource.go +++ b/store/memo_resource.go @@ -31,12 +31,6 @@ type DeleteMemoResource struct { } func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set := []string{"memo_id", "resource_id"} args := []any{upsert.MemoID, upsert.ResourceID} placeholder := []string{"?", "?"} @@ -56,7 +50,7 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour RETURNING memo_id, resource_id, created_ts, updated_ts ` memoResource := &MemoResource{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, query, args...).Scan( &memoResource.MemoID, &memoResource.ResourceID, &memoResource.CreatedTs, @@ -65,86 +59,10 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return memoResource, nil } func (s *Store) ListMemoResources(ctx context.Context, find *FindMemoResource) ([]*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listMemoResources(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - memoResource := list[0] - return memoResource, nil -} - -func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{}, []any{} - - if v := delete.MemoID; v != nil { - where, args = append(where, "memo_id = ?"), append(args, *v) - } - if v := delete.ResourceID; v != nil { - where, args = append(where, "resource_id = ?"), append(args, *v) - } - - stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") - _, err = tx.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - - return nil -} - -func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) ([]*MemoResource, error) { where, args := []string{"1 = 1"}, []any{} if v := find.MemoID; v != nil { @@ -164,7 +82,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) WHERE ` + strings.Join(where, " AND ") + ` ORDER BY updated_ts DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -192,6 +110,38 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) return list, nil } +func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) { + list, err := s.ListMemoResources(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memoResource := list[0] + return memoResource, nil +} + +func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { + where, args := []string{}, []any{} + if v := delete.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if v := delete.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} + func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/resource.go b/store/resource.go index 6fbfebd5..7700124c 100644 --- a/store/resource.go +++ b/store/resource.go @@ -46,13 +46,7 @@ type DeleteResource struct { } func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - if err := tx.QueryRowContext(ctx, ` + stmt := ` INSERT INTO resource ( filename, blob, @@ -64,131 +58,26 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource ) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id, created_ts, updated_ts - `, - create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, + ` + if err := s.db.QueryRowContext( + ctx, + stmt, + create.Filename, + create.Blob, + create.ExternalLink, + create.Type, + create.Size, + create.CreatorID, + create.InternalPath, ).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - resource := create return resource, nil } func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - resources, err := listResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return resources, nil -} - -func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - resources, err := listResources(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(resources) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return resources[0], nil -} - -func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.Filename; v != nil { - set, args = append(set, "filename = ?"), append(args, *v) - } - - args = append(args, update.ID) - fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"} - query := ` - UPDATE resource - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING ` + strings.Join(fields, ", ") - resource := Resource{} - dests := []any{ - &resource.ID, - &resource.Filename, - &resource.ExternalLink, - &resource.Type, - &resource.Size, - &resource.CreatorID, - &resource.CreatedTs, - &resource.UpdatedTs, - &resource.InternalPath, - } - if err := tx.QueryRowContext(ctx, query, args...).Scan(dests...); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return &resource, nil -} - -func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, ` - DELETE FROM resource - WHERE id = ? - `, delete.ID); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - - return nil -} - -func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Resource, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -226,7 +115,7 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso } } - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -263,6 +152,74 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso return list, nil } +func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) { + resources, err := s.ListResources(ctx, find) + if err != nil { + return nil, err + } + + if len(resources) == 0 { + return nil, nil + } + + return resources[0], nil +} + +func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) { + set, args := []string{}, []any{} + + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.Filename; v != nil { + set, args = append(set, "filename = ?"), append(args, *v) + } + + args = append(args, update.ID) + fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"} + stmt := ` + UPDATE resource + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING ` + strings.Join(fields, ", ") + resource := Resource{} + dests := []any{ + &resource.ID, + &resource.Filename, + &resource.ExternalLink, + &resource.Type, + &resource.Size, + &resource.CreatorID, + &resource.CreatedTs, + &resource.UpdatedTs, + &resource.InternalPath, + } + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil { + return nil, err + } + + return &resource, nil +} + +func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { + stmt := ` + DELETE FROM resource + WHERE id = ? + ` + result, err := s.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + return nil +} + func vacuumResource(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/shortcut.go b/store/shortcut.go index 7fb4047f..969e8b96 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -41,13 +41,7 @@ type DeleteShortcut struct { } func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO shortcut ( title, payload, @@ -56,7 +50,7 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut VALUES (?, ?, ?) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.Title, create.Payload, create.CreatorID).Scan( &create.ID, &create.CreatedTs, &create.UpdatedTs, @@ -65,134 +59,11 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - shortcut := create return shortcut, nil } func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listShortcuts(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listShortcuts(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - shortcut := list[0] - return shortcut, nil -} - -func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.Title; v != nil { - set, args = append(set, "title = ?"), append(args, *v) - } - if v := update.Payload; v != nil { - set, args = append(set, "payload = ?"), append(args, *v) - } - if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = ?"), append(args, *v) - } - args = append(args, update.ID) - - query := ` - UPDATE shortcut - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status - ` - shortcut := &Shortcut{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( - &shortcut.ID, - &shortcut.Title, - &shortcut.Payload, - &shortcut.CreatorID, - &shortcut.CreatedTs, - &shortcut.UpdatedTs, - &shortcut.RowStatus, - ); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return shortcut, nil -} - -func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - where, args := []string{}, []any{} - if v := delete.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := delete.CreatorID; v != nil { - where, args = append(where, "creator_id = ?"), append(args, *v) - } - - stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ") - if _, err := tx.ExecContext(ctx, stmt, args...); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return err - } - - s.shortcutCache.Delete(*delete.ID) - return nil -} - -func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -205,7 +76,7 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor where, args = append(where, "title = ?"), append(args, *v) } - rows, err := tx.QueryContext(ctx, ` + rows, err := s.db.QueryContext(ctx, ` SELECT id, title, @@ -248,6 +119,78 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor return list, nil } +func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { + list, err := s.ListShortcuts(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + shortcut := list[0] + return shortcut, nil +} + +func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) { + set, args := []string{}, []any{} + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.Title; v != nil { + set, args = append(set, "title = ?"), append(args, *v) + } + if v := update.Payload; v != nil { + set, args = append(set, "payload = ?"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + args = append(args, update.ID) + + stmt := ` + UPDATE shortcut + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status + ` + shortcut := &Shortcut{} + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( + &shortcut.ID, + &shortcut.Title, + &shortcut.Payload, + &shortcut.CreatorID, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &shortcut.RowStatus, + ); err != nil { + return nil, err + } + + return shortcut, nil +} + +func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { + where, args := []string{}, []any{} + if v := delete.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := delete.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + s.shortcutCache.Delete(*delete.ID) + return nil +} + func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/storage.go b/store/storage.go index d043c3eb..8e8c7802 100644 --- a/store/storage.go +++ b/store/storage.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -28,13 +27,7 @@ type DeleteStorage struct { } func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO storage ( name, type, @@ -43,136 +36,23 @@ func (s *Store) CreateStorage(ctx context.Context, create *Storage) (*Storage, e VALUES (?, ?, ?) RETURNING id ` - if err := tx.QueryRowContext(ctx, query, create.Name, create.Type, create.Config).Scan( + if err := s.db.QueryRowContext(ctx, stmt, create.Name, create.Type, create.Config).Scan( &create.ID, ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - storage := create return storage, nil } func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listStorages(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listStorages(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return list[0], nil -} - -func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - set, args := []string{}, []any{} - if update.Name != nil { - set = append(set, "name = ?") - args = append(args, *update.Name) - } - if update.Config != nil { - set = append(set, "config = ?") - args = append(args, *update.Config) - } - args = append(args, update.ID) - - query := ` - UPDATE storage - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING - id, - name, - type, - config - ` - storage := &Storage{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( - &storage.ID, - &storage.Name, - &storage.Type, - &storage.Config, - ); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return storage, nil -} - -func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - query := ` - DELETE FROM storage - WHERE id = ? - ` - if _, err := tx.ExecContext(ctx, query, delete.ID); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. - return err - } - - return nil -} - -func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storage, error) { where, args := []string{"1 = 1"}, []any{} if find.ID != nil { where, args = append(where, "id = ?"), append(args, *find.ID) } - rows, err := tx.QueryContext(ctx, ` + rows, err := s.db.QueryContext(ctx, ` SELECT id, name, @@ -208,3 +88,65 @@ func listStorages(ctx context.Context, tx *sql.Tx, find *FindStorage) ([]*Storag return list, nil } + +func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, error) { + list, err := s.ListStorages(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + return list[0], nil +} + +func (s *Store) UpdateStorage(ctx context.Context, update *UpdateStorage) (*Storage, error) { + set, args := []string{}, []any{} + if update.Name != nil { + set = append(set, "name = ?") + args = append(args, *update.Name) + } + if update.Config != nil { + set = append(set, "config = ?") + args = append(args, *update.Config) + } + args = append(args, update.ID) + + stmt := ` + UPDATE storage + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING + id, + name, + type, + config + ` + storage := &Storage{} + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, + ); err != nil { + return nil, err + } + + return storage, nil +} + +func (s *Store) DeleteStorage(ctx context.Context, delete *DeleteStorage) error { + stmt := ` + DELETE FROM storage + WHERE id = ? + ` + result, err := s.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + return nil +} diff --git a/store/system_setting.go b/store/system_setting.go index 6c06ff5b..aa94849f 100644 --- a/store/system_setting.go +++ b/store/system_setting.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -17,13 +16,7 @@ type FindSystemSetting struct { } func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO system_setting ( name, value, description ) @@ -33,11 +26,7 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) value = EXCLUDED.value, description = EXCLUDED.description ` - if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.Value, upsert.Description); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil { return nil, err } @@ -46,68 +35,6 @@ func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) } func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listSystemSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, systemSettingMessage := range list { - s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) - } - return list, nil -} - -func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (*SystemSetting, error) { - if find.Name != "" { - if cache, ok := s.systemSettingCache.Load(find.Name); ok { - return cache.(*SystemSetting), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listSystemSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - systemSettingMessage := list[0] - s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) - return systemSettingMessage, nil -} - -func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string { - if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{ - Name: settingName, - }); err == nil && setting != nil { - return setting.Value - } - return defaultValue -} - -func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) { where, args := []string{"1 = 1"}, []any{} if find.Name != "" { where, args = append(where, "name = ?"), append(args, find.Name) @@ -121,7 +48,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting FROM system_setting WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -144,5 +71,38 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting return nil, err } + for _, systemSettingMessage := range list { + s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) + } return list, nil } + +func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (*SystemSetting, error) { + if find.Name != "" { + if cache, ok := s.systemSettingCache.Load(find.Name); ok { + return cache.(*SystemSetting), nil + } + } + + list, err := s.ListSystemSettings(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + systemSettingMessage := list[0] + s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) + return systemSettingMessage, nil +} + +func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string { + if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{ + Name: settingName, + }); err == nil && setting != nil { + return setting.Value + } + return defaultValue +} diff --git a/store/tag.go b/store/tag.go index c6295291..37b0077f 100644 --- a/store/tag.go +++ b/store/tag.go @@ -3,7 +3,6 @@ package store import ( "context" "database/sql" - "fmt" "strings" ) @@ -22,13 +21,7 @@ type DeleteTag struct { } func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO tag ( name, creator_id ) @@ -37,11 +30,7 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { SET name = EXCLUDED.name ` - if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.CreatorID); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID); err != nil { return nil, err } @@ -50,12 +39,6 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *Tag) (*Tag, error) { } func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - where, args := []string{"creator_id = ?"}, []any{find.CreatorID} query := ` SELECT @@ -65,7 +48,7 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { WHERE ` + strings.Join(where, " AND ") + ` ORDER BY name ASC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -88,37 +71,19 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return list, nil } func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} - query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, query, args...) + stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") + result, err := s.db.ExecContext(ctx, stmt, args...) if err != nil { return err } - - rows, _ := result.RowsAffected() - if rows == 0 { - return fmt.Errorf("tag not found") - } - - if err := tx.Commit(); err != nil { - // Prevent linter warning. + if _, err = result.RowsAffected(); err != nil { return err } - return nil } diff --git a/store/user.go b/store/user.go index 13b71ef4..4ce11d70 100644 --- a/store/user.go +++ b/store/user.go @@ -2,8 +2,6 @@ package store import ( "context" - "database/sql" - "errors" "strings" ) @@ -79,13 +77,7 @@ type DeleteUser struct { } func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user ( username, role, @@ -97,7 +89,9 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { VALUES (?, ?, ?, ?, ?, ?) RETURNING id, avatar_url, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, + if err := s.db.QueryRowContext( + ctx, + stmt, create.Username, create.Role, create.Email, @@ -113,9 +107,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { ); err != nil { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } user := create s.userCache.Store(user.ID, user) @@ -123,12 +114,6 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { } func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if v := update.UpdatedTs; v != nil { set, args = append(set, "updated_ts = ?"), append(args, *v) @@ -163,7 +148,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status ` user := &User{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, query, args...).Scan( &user.ID, &user.Username, &user.Role, @@ -179,100 +164,11 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - s.userCache.Store(user.ID, user) return user, nil } func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUsers(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, user := range list { - s.userCache.Store(user.ID, user) - } - return list, nil -} - -func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { - if find.ID != nil { - if cache, ok := s.userCache.Load(*find.ID); ok { - return cache.(*User), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUsers(ctx, tx, find) - if err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - user := list[0] - s.userCache.Store(user.ID, user) - return user, nil -} - -func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - result, err := tx.ExecContext(ctx, ` - DELETE FROM user WHERE id = ? - `, delete.ID) - if err != nil { - return err - } - - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return errors.New("user not found") - } - if err := s.vacuumImpl(ctx, tx); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return err - } - - s.userCache.Delete(delete.ID) - return nil -} - -func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -311,7 +207,7 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) WHERE ` + strings.Join(where, " AND ") + ` ORDER BY created_ts DESC, row_status DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -342,5 +238,46 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) return nil, err } + for _, user := range list { + s.userCache.Store(user.ID, user) + } return list, nil } + +func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { + if find.ID != nil { + if cache, ok := s.userCache.Load(*find.ID); ok { + return cache.(*User), nil + } + } + + list, err := s.ListUsers(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + user := list[0] + s.userCache.Store(user.ID, user) + return user, nil +} + +func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { + result, err := s.db.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + if err := s.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + s.userCache.Delete(delete.ID) + return nil +} diff --git a/store/user_setting.go b/store/user_setting.go index 8fd6c284..fa948167 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -18,13 +18,7 @@ type FindUserSetting struct { } func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user_setting ( user_id, key, value ) @@ -32,11 +26,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us ON CONFLICT(user_id, key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := tx.ExecContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value); err != nil { return nil, err } @@ -46,59 +36,6 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us } func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - userSettingList, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - for _, userSetting := range userSettingList { - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) - } - return userSettingList, nil -} - -func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { - if find.UserID != nil { - if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { - return cache.(*UserSetting), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - userSetting := list[0] - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) - return userSetting, nil -} - -func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([]*UserSetting, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != "" { @@ -115,7 +52,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ value FROM user_setting WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -138,9 +75,33 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ return nil, err } + for _, userSetting := range userSettingList { + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) + } return userSettingList, nil } +func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { + if find.UserID != nil { + if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { + return cache.(*UserSetting), nil + } + } + + list, err := s.ListUserSettings(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + userSetting := list[0] + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) + return userSetting, nil +} + func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/test/store/store_test.go b/test/store/store_test.go new file mode 100644 index 00000000..b82a1ad3 --- /dev/null +++ b/test/store/store_test.go @@ -0,0 +1,47 @@ +package teststore + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/store" +) + +func TestConcurrentReadWrite(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + + const numWorkers = 10 + const numIterations = 100 + + wg := sync.WaitGroup{} + wg.Add(numWorkers) + + for i := 0; i < numWorkers; i++ { + go func() { + for j := 0; j < numIterations; j++ { + _, err := ts.CreateMemo(ctx, &store.Memo{ + CreatorID: user.ID, + Content: fmt.Sprintf("test_content_%d", i), + Visibility: store.Public, + }) + require.NoError(t, err) + } + }() + + go func() { + _, err := ts.ListMemos(ctx, &store.FindMemo{ + CreatorID: &user.ID, + }) + require.NoError(t, err) + wg.Done() + }() + } + + wg.Wait() +}