feat: allow tag filtering and count retrieval via api v1 (#1079)

* fix: frontend url to retrieve bookmark count

* chore: unneeded type in generic

* feat: allow tag filtering and count retrieval

* fix: make styles

* fix: make swagger

* fix: make swag

* tests: refactored gettags tests

* fix: initialise tags empty slice
This commit is contained in:
Felipe Martin 2025-03-12 23:10:50 +01:00 committed by GitHub
parent cdc13edb77
commit 21165aa2e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 734 additions and 236 deletions

View file

@ -417,6 +417,20 @@ const docTemplate = `{
"Tags"
],
"summary": "List tags",
"parameters": [
{
"type": "boolean",
"description": "Include bookmark count for each tag",
"name": "with_bookmark_count",
"in": "query"
},
{
"type": "integer",
"description": "Filter tags by bookmark ID",
"name": "bookmark_id",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",

View file

@ -406,6 +406,20 @@
"Tags"
],
"summary": "List tags",
"parameters": [
{
"type": "boolean",
"description": "Include bookmark count for each tag",
"name": "with_bookmark_count",
"in": "query"
},
{
"type": "integer",
"description": "Filter tags by bookmark ID",
"name": "bookmark_id",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",

View file

@ -441,6 +441,15 @@ paths:
/api/v1/tags:
get:
description: List all tags
parameters:
- description: Include bookmark count for each tag
in: query
name: with_bookmark_count
type: boolean
- description: Filter tags by bookmark ID
in: query
name: bookmark_id
type: integer
produces:
- application/json
responses:

View file

@ -124,7 +124,7 @@ func initShiori(ctx context.Context, cmd *cobra.Command) (*config.Config, *depen
account := model.AccountDTO{
Username: "shiori",
Password: "gopher",
Owner: model.Ptr[bool](true),
Owner: model.Ptr(true),
}
if _, err := dependencies.Domains().Accounts().CreateAccount(cmd.Context(), account); err != nil {

View file

@ -2,12 +2,14 @@ package database
import (
"context"
"database/sql"
"fmt"
"log"
"net/url"
"strings"
"github.com/go-shiori/shiori/internal/model"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
@ -39,11 +41,25 @@ func Connect(ctx context.Context, dbURL string) (model.DB, error) {
}
type dbbase struct {
*sqlx.DB
flavor sqlbuilder.Flavor
reader *sqlx.DB
writer *sqlx.DB
}
func (db *dbbase) Flavor() sqlbuilder.Flavor {
return db.flavor
}
func (db *dbbase) ReaderDB() *sqlx.DB {
return db.reader
}
func (db *dbbase) WriterDB() *sqlx.DB {
return db.writer
}
func (db *dbbase) withTx(ctx context.Context, fn func(tx *sqlx.Tx) error) error {
tx, err := db.BeginTxx(ctx, nil)
tx, err := db.writer.BeginTxx(ctx, nil)
if err != nil {
return errors.WithStack(err)
}
@ -64,3 +80,32 @@ func (db *dbbase) withTx(ctx context.Context, fn func(tx *sqlx.Tx) error) error
return err
}
func (db *dbbase) GetContext(ctx context.Context, dest any, query string, args ...any) error {
return db.reader.GetContext(ctx, dest, query, args...)
}
// Deprecated: Use SelectContext instead.
func (db *dbbase) Select(dest any, query string, args ...any) error {
return db.reader.Select(dest, query, args...)
}
func (db *dbbase) SelectContext(ctx context.Context, dest any, query string, args ...any) error {
return db.reader.SelectContext(ctx, dest, query, args...)
}
func (db *dbbase) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return db.writer.ExecContext(ctx, query, args...)
}
func (db *dbbase) MustBegin() *sqlx.Tx {
return db.writer.MustBegin()
}
func NewDBBase(reader, writer *sqlx.DB, flavor sqlbuilder.Flavor) dbbase {
return dbbase{
reader: reader,
writer: writer,
flavor: flavor,
}
}

View file

@ -0,0 +1,69 @@
package database
import (
"context"
"database/sql"
"fmt"
"log/slog"
"github.com/go-shiori/shiori/internal/model"
"github.com/huandu/go-sqlbuilder"
)
// GetTags returns a list of tags from the database.
// If opts.WithBookmarkCount is true, the result will include the number of bookmarks for each tag.
// If opts.BookmarkID is not 0, the result will include only the tags for the specified bookmark.
// If opts.OrderBy is set, the result will be ordered by the specified column.
func (db *dbbase) GetTags(ctx context.Context, opts model.DBListTagsOptions) ([]model.TagDTO, error) {
sb := db.Flavor().NewSelectBuilder()
sb.Select("t.id", "t.name")
sb.From("tag t")
// Treat the case where we want the bookmark count and filter by bookmark ID as a special case:
// If we only want one of them, we can use a JOIN and GROUP BY.
// If we want both, we need to use a subquery to get the count of bookmarks for each tag filtered
// by bookmark ID.
if opts.WithBookmarkCount && opts.BookmarkID == 0 {
// Join with bookmark_tag and group by tag ID to get the count of bookmarks for each tag
sb.JoinWithOption(sqlbuilder.LeftJoin, "bookmark_tag bt", "bt.tag_id = t.id")
sb.SelectMore("COUNT(bt.tag_id) AS bookmark_count")
sb.GroupBy("t.id")
} else if opts.BookmarkID > 0 {
// If we want the bookmark count, we need to use a subquery to get the count of bookmarks for each tag
if opts.WithBookmarkCount {
sb.SelectMore(
sb.BuilderAs(
db.Flavor().NewSelectBuilder().Select("COUNT(bt2.tag_id)").From("bookmark_tag bt2").Where("bt2.tag_id = t.id"),
"bookmark_count",
),
)
}
// Join with bookmark_tag and filter by bookmark ID to get the tags for a specific bookmark
sb.JoinWithOption(sqlbuilder.RightJoin, "bookmark_tag bt",
sb.And(
"bt.tag_id = t.id",
sb.Equal("bt.bookmark_id", opts.BookmarkID),
),
)
sb.Where(sb.IsNotNull("t.id"))
}
if opts.OrderBy == model.DBTagOrderByTagName {
sb.OrderBy("t.name")
}
query, args := sb.Build()
query = db.ReaderDB().Rebind(query)
slog.Info("GetTags query", "query", query, "args", args)
tags := []model.TagDTO{}
err := db.ReaderDB().SelectContext(ctx, &tags, query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to get tags: %w", err)
}
return tags, nil
}

View file

@ -0,0 +1,244 @@
package database
import (
"context"
"testing"
"github.com/go-shiori/shiori/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testGetTagsFunction tests the GetTags function with various options
func testGetTagsFunction(t *testing.T, db model.DB) {
ctx := context.TODO()
// Create test tags
tags := []model.Tag{
{Name: "golang"},
{Name: "database"},
{Name: "testing"},
{Name: "web"},
}
createdTags, err := db.CreateTags(ctx, tags...)
require.NoError(t, err)
require.Len(t, createdTags, 4)
// Map tag names to IDs for easier reference
tagIDsByName := make(map[string]int)
for _, tag := range createdTags {
tagIDsByName[tag.Name] = tag.ID
}
// Create bookmarks with different tag combinations
bookmarks := []model.BookmarkDTO{
{
URL: "https://golang.org",
Title: "Go Language",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: "golang"}},
{Tag: model.Tag{Name: "web"}},
},
},
{
URL: "https://postgresql.org",
Title: "PostgreSQL",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: "database"}},
},
},
{
URL: "https://sqlite.org",
Title: "SQLite",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: "database"}},
{Tag: model.Tag{Name: "testing"}},
},
},
}
// Save bookmarks
var savedBookmarks []model.BookmarkDTO
for _, bookmark := range bookmarks {
result, err := db.SaveBookmarks(ctx, true, bookmark)
require.NoError(t, err)
require.Len(t, result, 1)
savedBookmarks = append(savedBookmarks, result[0])
}
// Verify test data setup
t.Run("VerifyTestData", func(t *testing.T) {
// Check that all bookmarks were saved with their tags
for i, bookmark := range savedBookmarks {
assert.NotZero(t, bookmark.ID)
assert.Len(t, bookmark.Tags, len(bookmarks[i].Tags))
}
// Verify that the first bookmark has golang and web tags
assert.Len(t, savedBookmarks[0].Tags, 2)
tagNames := []string{savedBookmarks[0].Tags[0].Name, savedBookmarks[0].Tags[1].Name}
assert.Contains(t, tagNames, "golang")
assert.Contains(t, tagNames, "web")
})
// Test 1: Get all tags without any options
t.Run("GetAllTags", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
// Should return all 4 tags
assert.Len(t, fetchedTags, 4)
// Verify all tag names are present
tagNames := make(map[string]bool)
for _, tag := range fetchedTags {
tagNames[tag.Name] = true
}
for _, expectedTag := range tags {
assert.True(t, tagNames[expectedTag.Name], "Tag %s should be present", expectedTag.Name)
}
})
// Test 2: Get tags with bookmark count
t.Run("GetTagsWithBookmarkCount", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
WithBookmarkCount: true,
})
require.NoError(t, err)
// Should return all 4 tags
assert.Len(t, fetchedTags, 4)
// Create a map of tag name to bookmark count
tagCounts := make(map[string]int64)
for _, tag := range fetchedTags {
tagCounts[tag.Name] = tag.BookmarkCount
}
// Verify counts
assert.Equal(t, int64(1), tagCounts["golang"])
assert.Equal(t, int64(2), tagCounts["database"])
assert.Equal(t, int64(1), tagCounts["testing"])
assert.Equal(t, int64(1), tagCounts["web"])
})
// Test 3: Get tags for a specific bookmark
t.Run("GetTagsForBookmark", func(t *testing.T) {
// Get tags for the first bookmark (Go Language with golang and web tags)
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
BookmarkID: savedBookmarks[0].ID,
})
require.NoError(t, err)
// Should return 2 tags
assert.Len(t, fetchedTags, 2)
// Verify tag names
tagNames := make(map[string]bool)
for _, tag := range fetchedTags {
tagNames[tag.Name] = true
}
assert.True(t, tagNames["golang"], "Tag 'golang' should be present")
assert.True(t, tagNames["web"], "Tag 'web' should be present")
})
// Test 4: Get tags for a specific bookmark with bookmark count
t.Run("GetTagsForBookmarkWithCount", func(t *testing.T) {
// Get tags for the third bookmark (SQLite with database and testing tags)
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
BookmarkID: savedBookmarks[2].ID,
WithBookmarkCount: true,
})
require.NoError(t, err)
// Should return 2 tags
assert.Len(t, fetchedTags, 2)
// Create a map of tag name to bookmark count
tagCounts := make(map[string]int64)
for _, tag := range fetchedTags {
tagCounts[tag.Name] = tag.BookmarkCount
}
// Verify counts - database should have 2 bookmarks, testing should have 1
assert.Equal(t, int64(2), tagCounts["database"])
assert.Equal(t, int64(1), tagCounts["testing"])
})
// Test 5: Get tags ordered by name
t.Run("GetTagsOrderedByName", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
OrderBy: model.DBTagOrderByTagName,
})
require.NoError(t, err)
// Should return all 4 tags in alphabetical order
assert.Len(t, fetchedTags, 4)
// Verify order
assert.Equal(t, "database", fetchedTags[0].Name)
assert.Equal(t, "golang", fetchedTags[1].Name)
assert.Equal(t, "testing", fetchedTags[2].Name)
assert.Equal(t, "web", fetchedTags[3].Name)
})
// Test 6: Get tags for a non-existent bookmark
t.Run("GetTagsForNonExistentBookmark", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
BookmarkID: 9999, // Non-existent ID
})
require.NoError(t, err)
// Should return empty result
assert.Empty(t, fetchedTags)
})
// Test 7: Get tags for a bookmark with no tags
t.Run("GetTagsForBookmarkWithNoTags", func(t *testing.T) {
// Create a bookmark with no tags
bookmarkWithNoTags := model.BookmarkDTO{
URL: "https://example.com",
Title: "Example with no tags",
}
result, err := db.SaveBookmarks(ctx, true, bookmarkWithNoTags)
require.NoError(t, err)
require.Len(t, result, 1)
// Get tags for this bookmark
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
BookmarkID: result[0].ID,
})
require.NoError(t, err)
// Should return empty result
assert.Empty(t, fetchedTags)
})
// Test 8: Get tags with combined options (order + count)
t.Run("GetTagsWithCombinedOptions", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{
WithBookmarkCount: true,
OrderBy: model.DBTagOrderByTagName,
})
require.NoError(t, err)
// Should return all 4 tags in alphabetical order with counts
assert.Len(t, fetchedTags, 4)
// Verify order and counts
assert.Equal(t, "database", fetchedTags[0].Name)
assert.Equal(t, int64(2), fetchedTags[0].BookmarkCount)
assert.Equal(t, "golang", fetchedTags[1].Name)
assert.Equal(t, int64(1), fetchedTags[1].BookmarkCount)
assert.Equal(t, "testing", fetchedTags[2].Name)
assert.Equal(t, int64(1), fetchedTags[2].BookmarkCount)
assert.Equal(t, "web", fetchedTags[3].Name)
assert.Equal(t, int64(1), fetchedTags[3].BookmarkCount)
})
}

View file

@ -35,10 +35,11 @@ func testDatabase(t *testing.T, dbFactory testDatabaseFactory) {
"testSaveBookmark": testSaveBookmark,
"testBulkUpdateBookmarkTags": testBulkUpdateBookmarkTags,
// Tags
"testCreateTag": testCreateTag,
"testCreateTags": testCreateTags,
"testGetTags": testGetTags,
"testGetTagsBookmarkCount": testGetTagsBookmarkCount,
"testCreateTag": testCreateTag,
"testCreateTags": testCreateTags,
"testGetTags": testGetTags,
"testGetTagsFunction": testGetTagsFunction,
// "testGetTagsBookmarkCount": testGetTagsBookmarkCount,
"testGetTag": testGetTag,
"testGetTagNotExistent": testGetTagNotExistent,
"testUpdateTag": testUpdateTag,
@ -428,7 +429,7 @@ func testGetBookmarksWithTags(t *testing.T, db model.DB) {
}
t.Run("ensure tags are present", func(t *testing.T) {
tags, err := db.GetTags(ctx)
tags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
assert.Len(t, tags, 4)
})
@ -831,7 +832,7 @@ func testGetTags(t *testing.T, db model.DB) {
require.Len(t, createdTags, 3)
// Fetch all tags
fetchedTags, err := db.GetTags(ctx)
fetchedTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
require.GreaterOrEqual(t, len(fetchedTags), 4) // At least 3 new tags + 1 initial tag
@ -949,107 +950,6 @@ func testDeleteTagNotExistent(t *testing.T, db model.DB) {
assert.ErrorIs(t, err, ErrNotFound, "Error should be ErrNotFound")
}
func testGetTagsBookmarkCount(t *testing.T, db model.DB) {
ctx := context.TODO()
// Create test tags
tags := []model.Tag{
{Name: "tag1-count"},
{Name: "tag2-count"},
}
_, err := db.CreateTags(ctx, model.Tag{Name: "tag3-count"})
require.NoError(t, err)
// Create bookmarks with different tag combinations
bookmark1 := model.BookmarkDTO{
URL: "https://example1.com",
Title: "Example 1",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: tags[0].Name}}, // tag1
{Tag: model.Tag{Name: tags[1].Name}}, // tag2
},
}
bookmark2 := model.BookmarkDTO{
URL: "https://example2.com",
Title: "Example 2",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: tags[0].Name}}, // tag1
},
}
bookmark3 := model.BookmarkDTO{
URL: "https://example3.com",
Title: "Example 3",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: tags[1].Name}}, // tag2
},
}
// Save bookmarks
bookmarks, err := db.SaveBookmarks(ctx, true, bookmark1, bookmark2, bookmark3)
require.NoError(t, err)
t.Run("GetBookmarks", func(t *testing.T) {
result, err := db.GetBookmarks(ctx, model.DBGetBookmarksOptions{
Tags: []string{tags[0].Name},
})
require.NoError(t, err)
require.NotEmpty(t, result)
})
t.Run("GetTag", func(t *testing.T) {
t.Log(bookmarks[0])
tag, exists, err := db.GetTag(ctx, bookmarks[0].Tags[0].ID)
require.NoError(t, err)
require.True(t, exists)
assert.Equal(t, tags[0].Name, tag.Name)
assert.Equal(t, int64(2), tag.BookmarkCount)
})
// Test GetTags
t.Run("GetTags", func(t *testing.T) {
fetchedTags, err := db.GetTags(ctx)
require.NoError(t, err)
require.GreaterOrEqual(t, len(fetchedTags), 3)
// Create a map of tag name to bookmark count
tagCounts := make(map[string]int64)
for _, tag := range fetchedTags {
tagCounts[tag.Name] = tag.BookmarkCount
}
// Verify counts
assert.Equal(t, int64(2), tagCounts["tag1-count"])
assert.Equal(t, int64(2), tagCounts["tag2-count"])
assert.Equal(t, int64(0), tagCounts["tag3-count"])
})
// Test count updates after bookmark deletion
t.Run("CountAfterDeletion", func(t *testing.T) {
// Get the first bookmark that has tag1
bookmarks, err := db.GetBookmarks(ctx, model.DBGetBookmarksOptions{
Tags: []string{tags[0].Name},
})
require.NoError(t, err)
require.NotEmpty(t, bookmarks)
require.NotEmpty(t, bookmarks[0].Tags)
tagID := bookmarks[0].Tags[0].ID
// Delete the first bookmark
err = db.DeleteBookmarks(ctx, bookmarks[0].ID)
require.NoError(t, err)
// Verify updated counts
tag1, exists, err := db.GetTag(ctx, tagID)
require.NoError(t, err)
require.True(t, exists)
assert.Equal(t, int64(1), tag1.BookmarkCount, "tag1-count should have 1 bookmark after deletion")
})
}
func testSaveBookmark(t *testing.T, db model.DB) {
ctx := context.TODO()

View file

@ -87,20 +87,10 @@ func OpenMySQLDatabase(ctx context.Context, connString string) (mysqlDB *MySQLDa
db.SetMaxOpenConns(100)
db.SetConnMaxLifetime(time.Second) // in case mysql client has longer timeout (driver issue #674)
mysqlDB = &MySQLDatabase{dbbase: dbbase{db}}
mysqlDB = &MySQLDatabase{dbbase: NewDBBase(db, db, sqlbuilder.MySQL)}
return mysqlDB, err
}
// WriterDB returns the underlying sqlx.DB object
func (db *MySQLDatabase) WriterDB() *sqlx.DB {
return db.DB
}
// ReaderDB returns the underlying sqlx.DB object
func (db *MySQLDatabase) ReaderDB() *sqlx.DB {
return db.DB
}
// Init initializes the database
func (db *MySQLDatabase) Init(ctx context.Context) error {
return nil
@ -872,27 +862,6 @@ func (db *MySQLDatabase) RenameTag(ctx context.Context, id int, newName string)
return nil
}
// GetTags fetch list of tags and their frequency.
func (db *MySQLDatabase) GetTags(ctx context.Context) ([]model.TagDTO, error) {
sb := sqlbuilder.MySQL.NewSelectBuilder()
sb.Select("t.id", "t.name", "COUNT(bt.tag_id) AS bookmark_count")
sb.From("tag t")
sb.JoinWithOption(sqlbuilder.LeftJoin, "bookmark_tag bt", "bt.tag_id = t.id")
sb.GroupBy("t.id")
sb.OrderBy("t.name")
query, args := sb.Build()
query = db.ReaderDB().Rebind(query)
tags := []model.TagDTO{}
err := db.ReaderDB().SelectContext(ctx, &tags, query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to get tags: %w", err)
}
return tags, nil
}
// GetTag fetch a tag by its ID.
func (db *MySQLDatabase) GetTag(ctx context.Context, id int) (model.TagDTO, bool, error) {
sb := sqlbuilder.MySQL.NewSelectBuilder()

View file

@ -52,7 +52,7 @@ func mysqlTestDatabaseFactory(envKey string) testDatabaseFactory {
return nil, err
}
if _, err := db.Exec("USE " + dbname); err != nil {
if _, err := db.ExecContext(ctx, "USE "+dbname); err != nil {
return nil, err
}

View file

@ -89,20 +89,10 @@ func OpenPGDatabase(ctx context.Context, connString string) (pgDB *PGDatabase, e
db.SetMaxOpenConns(100)
db.SetConnMaxLifetime(time.Second)
pgDB = &PGDatabase{dbbase: dbbase{db}}
pgDB = &PGDatabase{dbbase: NewDBBase(db, db, sqlbuilder.PostgreSQL)}
return pgDB, err
}
// WriterDB returns the underlying sqlx.DB object
func (db *PGDatabase) WriterDB() *sqlx.DB {
return db.DB
}
// ReaderDB returns the underlying sqlx.DB object
func (db *PGDatabase) ReaderDB() *sqlx.DB {
return db.DB
}
// Init initializes the database
func (db *PGDatabase) Init(ctx context.Context) error {
return nil
@ -398,7 +388,7 @@ func (db *PGDatabase) GetBookmarks(ctx context.Context, opts model.DBGetBookmark
if err != nil {
return nil, fmt.Errorf("failed to expand query: %v", err)
}
query = db.Rebind(query)
query = db.ReaderDB().Rebind(query)
// Fetch bookmarks
bookmarks := []model.BookmarkDTO{}
@ -408,7 +398,7 @@ func (db *PGDatabase) GetBookmarks(ctx context.Context, opts model.DBGetBookmark
}
// Fetch tags for each bookmarks
stmtGetTags, err := db.PreparexContext(ctx, `SELECT t.id, t.name
stmtGetTags, err := db.ReaderDB().PreparexContext(ctx, `SELECT t.id, t.name
FROM bookmark_tag bt
LEFT JOIN tag t ON bt.tag_id = t.id
WHERE bt.bookmark_id = $1
@ -521,7 +511,7 @@ func (db *PGDatabase) GetBookmarksCount(ctx context.Context, opts model.DBGetBoo
if err != nil {
return 0, errors.WithStack(err)
}
query = db.Rebind(query)
query = db.ReaderDB().Rebind(query)
// Fetch count
var nBookmarks int
@ -899,27 +889,6 @@ func (db *PGDatabase) RenameTag(ctx context.Context, id int, newName string) err
return nil
}
// GetTags fetch list of tags and their frequency.
func (db *PGDatabase) GetTags(ctx context.Context) ([]model.TagDTO, error) {
sb := sqlbuilder.PostgreSQL.NewSelectBuilder()
sb.Select("t.id", "t.name", "COUNT(bt.tag_id) bookmark_count")
sb.From("tag t")
sb.JoinWithOption(sqlbuilder.LeftJoin, "bookmark_tag bt", "bt.tag_id = t.id")
sb.GroupBy("t.id")
sb.OrderBy("t.name")
query, args := sb.Build()
query = db.ReaderDB().Rebind(query)
tags := []model.TagDTO{}
err := db.ReaderDB().SelectContext(ctx, &tags, query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to get tags: %w", err)
}
return tags, nil
}
// GetTag fetch a tag by its ID.
func (db *PGDatabase) GetTag(ctx context.Context, id int) (model.TagDTO, bool, error) {
sb := sqlbuilder.NewSelectBuilder()

View file

@ -25,7 +25,7 @@ func postgresqlTestDatabaseFactory(_ *testing.T, ctx context.Context) (model.DB,
return nil, err
}
_, err = db.Exec("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
_, err = db.ExecContext(ctx, "DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
if err != nil {
return nil, err
}

View file

@ -69,8 +69,7 @@ var sqliteMigrations = []migration{
// SQLiteDatabase is implementation of Database interface
// for connecting to SQLite3 database.
type SQLiteDatabase struct {
writer *dbbase
reader *dbbase
dbbase
}
// withTx executes the given function within a transaction.
@ -123,7 +122,7 @@ func (db *SQLiteDatabase) withTxRetry(ctx context.Context, fn func(tx *sqlx.Tx)
// Init sets up the SQLite database with optimal settings for both reader and writer connections
func (db *SQLiteDatabase) Init(ctx context.Context) error {
// Initialize both connections with appropriate settings
for _, conn := range []*dbbase{db.writer, db.reader} {
for _, conn := range []*sqlx.DB{db.WriterDB(), db.ReaderDB()} {
// Reuse connections for up to one hour
conn.SetConnMaxLifetime(time.Hour)
@ -168,12 +167,12 @@ type bookmarkContent struct {
// DBX returns the underlying sqlx.DB object for writes
func (db *SQLiteDatabase) WriterDB() *sqlx.DB {
return db.writer.DB
return db.dbbase.WriterDB()
}
// ReaderDBx returns the underlying sqlx.DB object for reading
func (db *SQLiteDatabase) ReaderDB() *sqlx.DB {
return db.reader.DB
return db.dbbase.ReaderDB()
}
// Migrate runs migrations for this database engine
@ -1050,27 +1049,6 @@ func (db *SQLiteDatabase) RenameTag(ctx context.Context, id int, newName string)
return nil
}
// GetTags fetch list of tags and their frequency.
func (db *SQLiteDatabase) GetTags(ctx context.Context) ([]model.TagDTO, error) {
sb := sqlbuilder.SQLite.NewSelectBuilder()
sb.Select("t.id", "t.name", "COUNT(bt.tag_id) AS bookmark_count")
sb.From("tag t")
sb.JoinWithOption(sqlbuilder.LeftJoin, "bookmark_tag bt", "bt.tag_id = t.id")
sb.GroupBy("t.id")
sb.OrderBy("t.name")
query, args := sb.Build()
query = db.ReaderDB().Rebind(query)
tags := []model.TagDTO{}
err := db.ReaderDB().SelectContext(ctx, &tags, query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to get tags: %w", err)
}
return tags, nil
}
// GetTag fetch a tag by its ID.
func (db *SQLiteDatabase) GetTag(ctx context.Context, id int) (model.TagDTO, bool, error) {
sb := sqlbuilder.SQLite.NewSelectBuilder()

View file

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
@ -26,8 +27,11 @@ func OpenSQLiteDatabase(ctx context.Context, databasePath string) (sqliteDB *SQL
}
sqliteDB = &SQLiteDatabase{
writer: &dbbase{rwDB},
reader: &dbbase{rDB},
dbbase: dbbase{
writer: rwDB,
reader: rDB,
flavor: sqlbuilder.SQLite,
},
}
if err := sqliteDB.Init(ctx); err != nil {

View file

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
_ "git.sr.ht/~emersion/go-sqlite3-fts5"
@ -27,8 +28,11 @@ func OpenSQLiteDatabase(ctx context.Context, databasePath string) (sqliteDB *SQL
}
sqliteDB = &SQLiteDatabase{
writer: &dbbase{rwDB},
reader: &dbbase{rDB},
dbbase: dbbase{
writer: rwDB,
reader: rDB,
flavor: sqlbuilder.SQLite,
},
}
if err := sqliteDB.Init(ctx); err != nil {

View file

@ -16,8 +16,8 @@ func NewTagsDomain(deps model.Dependencies) model.TagsDomain {
return &tagsDomain{deps: deps}
}
func (d *tagsDomain) ListTags(ctx context.Context) ([]model.TagDTO, error) {
tags, err := d.deps.Database().GetTags(ctx)
func (d *tagsDomain) ListTags(ctx context.Context, opts model.ListTagsOptions) ([]model.TagDTO, error) {
tags, err := d.deps.Database().GetTags(ctx, model.DBListTagsOptions(opts))
if err != nil {
return nil, err
}

View file

@ -35,7 +35,7 @@ func TestTagsDomain(t *testing.T) {
require.Len(t, createdTags, 2)
// List the tags
tags, err := tagsDomain.ListTags(ctx)
tags, err := tagsDomain.ListTags(ctx, model.ListTagsOptions{})
require.NoError(t, err)
require.Len(t, tags, 2)
@ -44,6 +44,125 @@ func TestTagsDomain(t *testing.T) {
assert.Equal(t, "tag2", tags[1].Name)
})
// Test ListTags with WithBookmarkCount
t.Run("ListTags_WithBookmarkCount", func(t *testing.T) {
// Create a test tag
tag := model.Tag{Name: "tag-with-count"}
createdTags, err := db.CreateTags(ctx, tag)
require.NoError(t, err)
require.Len(t, createdTags, 1)
// Create a bookmark with this tag
bookmark := model.BookmarkDTO{
URL: "https://example-count.com",
Title: "Example for Count",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: tag.Name}},
},
}
_, err = db.SaveBookmarks(ctx, true, bookmark)
require.NoError(t, err)
// List tags with bookmark count
tags, err := tagsDomain.ListTags(ctx, model.ListTagsOptions{
WithBookmarkCount: true,
})
require.NoError(t, err)
require.NotEmpty(t, tags)
// Find our test tag and verify it has a bookmark count
var foundTag model.TagDTO
for _, t := range tags {
if t.Name == tag.Name {
foundTag = t
break
}
}
require.NotZero(t, foundTag.ID, "Should find the test tag")
assert.Equal(t, int64(1), foundTag.BookmarkCount, "Tag should have a bookmark count of 1")
})
// Test ListTags with BookmarkID
t.Run("ListTags_WithBookmarkID", func(t *testing.T) {
// Create test tags
testTags := []model.Tag{
{Name: "tag-for-bookmark1"},
{Name: "tag-for-bookmark2"},
}
createdTags, err := db.CreateTags(ctx, testTags...)
require.NoError(t, err)
require.Len(t, createdTags, 2)
// Create bookmarks with different tags
bookmark1 := model.BookmarkDTO{
URL: "https://example-bookmark1.com",
Title: "Example Bookmark 1",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: testTags[0].Name}},
},
}
bookmark2 := model.BookmarkDTO{
URL: "https://example-bookmark2.com",
Title: "Example Bookmark 2",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: testTags[1].Name}},
},
}
savedBookmarks, err := db.SaveBookmarks(ctx, true, bookmark1, bookmark2)
require.NoError(t, err)
require.Len(t, savedBookmarks, 2)
// Get tags for the first bookmark
tags, err := tagsDomain.ListTags(ctx, model.ListTagsOptions{
BookmarkID: savedBookmarks[0].ID,
})
require.NoError(t, err)
require.Len(t, tags, 1, "Should return exactly one tag for the bookmark")
assert.Equal(t, testTags[0].Name, tags[0].Name, "Should return the correct tag for the bookmark")
// Get tags for the second bookmark
tags, err = tagsDomain.ListTags(ctx, model.ListTagsOptions{
BookmarkID: savedBookmarks[1].ID,
})
require.NoError(t, err)
require.Len(t, tags, 1, "Should return exactly one tag for the bookmark")
assert.Equal(t, testTags[1].Name, tags[0].Name, "Should return the correct tag for the bookmark")
})
// Test ListTags with both options
t.Run("ListTags_WithBothOptions", func(t *testing.T) {
// Create a test tag
tag := model.Tag{Name: "tag-with-both-options"}
createdTags, err := db.CreateTags(ctx, tag)
require.NoError(t, err)
require.Len(t, createdTags, 1)
// Create a bookmark with this tag
bookmark := model.BookmarkDTO{
URL: "https://example-both-options.com",
Title: "Example for Both Options",
Tags: []model.TagDTO{
{Tag: model.Tag{Name: tag.Name}},
},
}
savedBookmarks, err := db.SaveBookmarks(ctx, true, bookmark)
require.NoError(t, err)
require.Len(t, savedBookmarks, 1)
// List tags with both options
tags, err := tagsDomain.ListTags(ctx, model.ListTagsOptions{
BookmarkID: savedBookmarks[0].ID,
WithBookmarkCount: true,
})
require.NoError(t, err)
require.Len(t, tags, 1, "Should return exactly one tag")
assert.Equal(t, tag.Name, tags[0].Name, "Should return the correct tag")
assert.Equal(t, int64(1), tags[0].BookmarkCount, "Tag should have a bookmark count of 1")
})
// Test CreateTag
t.Run("CreateTag", func(t *testing.T) {
// Create a new tag
@ -59,9 +178,9 @@ func TestTagsDomain(t *testing.T) {
assert.Greater(t, createdTag.ID, 0, "The created tag should have a valid ID")
// Verify the tag was created in the database
allTags, err := db.GetTags(ctx)
allTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
require.Len(t, allTags, 3) // 2 from previous test + 1 new
require.GreaterOrEqual(t, len(allTags), 1) // At least our new tag
// Find the created tag in the list
var found bool
@ -78,7 +197,7 @@ func TestTagsDomain(t *testing.T) {
// Test GetTag - Success
t.Run("GetTag_Success", func(t *testing.T) {
// Get all tags to find an ID
allTags, err := db.GetTags(ctx)
allTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
require.NotEmpty(t, allTags)
@ -102,7 +221,7 @@ func TestTagsDomain(t *testing.T) {
// Test UpdateTag
t.Run("UpdateTag", func(t *testing.T) {
// Get all tags to find an ID
allTags, err := db.GetTags(ctx)
allTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
require.NotEmpty(t, allTags)
@ -131,7 +250,7 @@ func TestTagsDomain(t *testing.T) {
// Test DeleteTag
t.Run("DeleteTag", func(t *testing.T) {
// Get all tags to find an ID
allTags, err := db.GetTags(ctx)
allTags, err := db.GetTags(ctx, model.DBListTagsOptions{})
require.NoError(t, err)
require.NotEmpty(t, allTags)

View file

@ -15,16 +15,35 @@ import (
// @Tags Tags
// @securityDefinitions.apikey ApiKeyAuth
// @Produce json
// @Success 200 {array} model.TagDTO
// @Failure 403 {object} nil "Authentication required"
// @Failure 500 {object} nil "Internal server error"
// @Param with_bookmark_count query boolean false "Include bookmark count for each tag"
// @Param bookmark_id query integer false "Filter tags by bookmark ID"
// @Success 200 {array} model.TagDTO
// @Failure 403 {object} nil "Authentication required"
// @Failure 500 {object} nil "Internal server error"
// @Router /api/v1/tags [get]
func HandleListTags(deps model.Dependencies, c model.WebContext) {
if err := middleware.RequireLoggedInUser(deps, c); err != nil {
return
}
tags, err := deps.Domains().Tags().ListTags(c.Request().Context())
// Parse query parameters
withBookmarkCount := c.Request().URL.Query().Get("with_bookmark_count") == "true"
var bookmarkID int
if bookmarkIDStr := c.Request().URL.Query().Get("bookmark_id"); bookmarkIDStr != "" {
var err error
bookmarkID, err = strconv.Atoi(bookmarkIDStr)
if err != nil {
response.SendError(c, http.StatusBadRequest, "Invalid bookmark ID", nil)
return
}
}
tags, err := deps.Domains().Tags().ListTags(c.Request().Context(), model.ListTagsOptions{
WithBookmarkCount: withBookmarkCount,
BookmarkID: bookmarkID,
OrderBy: model.DBTagOrderByTagName,
})
if err != nil {
deps.Logger().WithError(err).Error("failed to get tags")
response.SendInternalServerError(c)
@ -75,7 +94,7 @@ func HandleGetTag(deps model.Dependencies, c model.WebContext) {
// @Description Create a new tag
// @Tags Tags
// @securityDefinitions.apikey ApiKeyAuth
// @Accept json
// @Accept json
// @Produce json
// @Param tag body model.TagDTO true "Tag data"
// @Success 201 {object} model.TagDTO
@ -114,9 +133,9 @@ func HandleCreateTag(deps model.Dependencies, c model.WebContext) {
// @Description Update an existing tag
// @Tags Tags
// @securityDefinitions.apikey ApiKeyAuth
// @Accept json
// @Accept json
// @Produce json
// @Param id path int true "Tag ID"
// @Param id path int true "Tag ID"
// @Param tag body model.TagDTO true "Tag data"
// @Success 200 {object} model.TagDTO
// @Failure 400 {object} nil "Invalid request"

View file

@ -40,6 +40,115 @@ func TestHandleListTags(t *testing.T) {
response.AssertOk(t)
response.AssertMessageIsNotEmptyList(t)
})
t.Run("with_bookmark_count parameter", func(t *testing.T) {
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
// Create a test tag
tag := model.Tag{Name: "test-tag-with-count"}
createdTags, err := deps.Database().CreateTags(ctx, tag)
require.NoError(t, err)
require.Len(t, createdTags, 1)
w := testutil.PerformRequest(
deps,
HandleListTags,
"GET",
"/api/v1/tags",
testutil.WithFakeUser(),
testutil.WithRequestQueryParam("with_bookmark_count", "true"),
)
require.Equal(t, http.StatusOK, w.Code)
response, err := testutil.NewTestResponseFromReader(w.Body)
require.NoError(t, err)
response.AssertOk(t)
// Verify the response contains tags with bookmark_count field
var tags []model.TagDTO
responseData, err := json.Marshal(response.Response.GetMessage())
require.NoError(t, err)
err = json.Unmarshal(responseData, &tags)
require.NoError(t, err)
require.NotEmpty(t, tags)
// The bookmark_count field should be present in the response
// Even if it's 0, it should be included when the parameter is set
for _, tag := range tags {
if tag.Name == "test-tag-with-count" {
// We're just checking that the field exists and is accessible
_ = tag.BookmarkCount
break
}
}
})
t.Run("invalid bookmark_id parameter", func(t *testing.T) {
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
w := testutil.PerformRequest(
deps,
HandleListTags,
"GET",
"/api/v1/tags",
testutil.WithFakeUser(),
testutil.WithRequestQueryParam("bookmark_id", "invalid"),
)
require.Equal(t, http.StatusBadRequest, w.Code)
})
t.Run("bookmark_id parameter", func(t *testing.T) {
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
// Create a test bookmark
bookmark := testutil.GetValidBookmark()
bookmarks, err := deps.Database().SaveBookmarks(ctx, true, *bookmark)
require.NoError(t, err)
require.Len(t, bookmarks, 1)
bookmarkID := bookmarks[0].ID
// Create a test tag
tag := model.Tag{Name: "test-tag-for-bookmark"}
createdTags, err := deps.Database().CreateTags(ctx, tag)
require.NoError(t, err)
require.Len(t, createdTags, 1)
// Associate the tag with the bookmark
err = deps.Database().BulkUpdateBookmarkTags(ctx, []int{bookmarkID}, []int{createdTags[0].ID})
require.NoError(t, err)
w := testutil.PerformRequest(
deps,
HandleListTags,
"GET",
"/api/v1/tags",
testutil.WithFakeUser(),
testutil.WithRequestQueryParam("bookmark_id", strconv.Itoa(bookmarkID)),
)
require.Equal(t, http.StatusOK, w.Code)
response, err := testutil.NewTestResponseFromReader(w.Body)
require.NoError(t, err)
response.AssertOk(t)
// Verify the response contains the tag associated with the bookmark
var tags []model.TagDTO
responseData, err := json.Marshal(response.Response.GetMessage())
require.NoError(t, err)
err = json.Unmarshal(responseData, &tags)
require.NoError(t, err)
// Check that we have at least one tag and it's the one we created
require.NotEmpty(t, tags)
found := false
for _, t := range tags {
if t.Name == "test-tag-for-bookmark" {
found = true
break
}
}
require.True(t, found, "The tag associated with the bookmark should be in the response")
})
}
func TestHandleGetTag(t *testing.T) {

View file

@ -6,6 +6,8 @@ import (
"github.com/jmoiron/sqlx"
)
type DBID int
// DB is interface for accessing and manipulating data in database.
type DB interface {
// WriterDB is the underlying sqlx.DB
@ -14,6 +16,9 @@ type DB interface {
// ReaderDB is the underlying sqlx.DB
ReaderDB() *sqlx.DB
// Flavor is the flavor of the database
// Flavor() sqlbuilder.Flavor
// Init initializes the database
Init(ctx context.Context) error
@ -67,7 +72,7 @@ type DB interface {
CreateTag(ctx context.Context, tag Tag) (Tag, error)
// GetTags fetch list of tags and its frequency from database.
GetTags(ctx context.Context) ([]TagDTO, error)
GetTags(ctx context.Context, opts DBListTagsOptions) ([]TagDTO, error)
// RenameTag change the name of a tag.
RenameTag(ctx context.Context, id int, newName string) error
@ -122,8 +127,15 @@ type DBListAccountsOptions struct {
WithPassword bool
}
type DBTagOrderBy string
const (
DBTagOrderByTagName DBTagOrderBy = "name"
)
// DBListTagsOptions is options for fetching tags from database.
type DBListTagsOptions struct {
BookmarkID int
WithBookmarkCount bool
OrderBy DBTagOrderBy
}

View file

@ -1,3 +0,0 @@
package model
type DBID int

View file

@ -48,7 +48,7 @@ type StorageDomain interface {
}
type TagsDomain interface {
ListTags(ctx context.Context) ([]TagDTO, error)
ListTags(ctx context.Context, opts ListTagsOptions) ([]TagDTO, error)
CreateTag(ctx context.Context, tag TagDTO) (TagDTO, error)
GetTag(ctx context.Context, id int) (TagDTO, error)
UpdateTag(ctx context.Context, tag TagDTO) (TagDTO, error)

View file

@ -34,3 +34,10 @@ func (t *TagDTO) ToTag() Tag {
Name: t.Name,
}
}
// ListTagsOptions is options for fetching tags from database.
type ListTagsOptions struct {
BookmarkID int
WithBookmarkCount bool
OrderBy DBTagOrderBy
}

View file

@ -61,12 +61,22 @@ func WithFakeAccount(isAdmin bool) Option {
}
}
// WithRequestPathValue adds a path value to the request
func WithRequestPathValue(key, value string) Option {
return func(c model.WebContext) {
c.Request().SetPathValue(key, value)
}
}
// WithRequestQueryParam adds a query parameter to the request
func WithRequestQueryParam(key, value string) Option {
return func(c model.WebContext) {
q := c.Request().URL.Query()
q.Add(key, value)
c.Request().URL.RawQuery = q.Encode()
}
}
// PerformRequest executes a request against a handler
func PerformRequest(deps model.Dependencies, handler model.HttpHandler, method, path string, options ...Option) *httptest.ResponseRecorder {
w := httptest.NewRecorder()

View file

@ -81,7 +81,7 @@ var template = `
<a @click="filterTag('*')">(all tagged)</a>
<a @click="filterTag('*', true)">(all untagged)</a>
<a v-for="tag in tags" @click="dialogTagClicked($event, tag)">
#{{tag.name}}<span>{{tag.nBookmarks}}</span>
#{{tag.name}}<span>{{tag.bookmark_count}}</span>
</a>
</custom-dialog>
<custom-dialog v-bind="dialog"/>
@ -257,12 +257,16 @@ export default {
// Fetch tags if requested
if (fetchTags) {
return fetch(new URL("api/v1/tags", document.baseURI), {
headers: {
"Content-Type": "application/json",
Authorization: "Bearer " + localStorage.getItem("shiori-token"),
return fetch(
new URL("api/v1/tags?with_bookmark_count=true", document.baseURI),
{
headers: {
"Content-Type": "application/json",
Authorization:
"Bearer " + localStorage.getItem("shiori-token"),
},
},
});
);
} else {
this.loading = false;
throw skipFetchTags;

View file

@ -131,7 +131,9 @@ func (h *Handler) ApiGetTags(w http.ResponseWriter, r *http.Request, ps httprout
checkError(err)
// Fetch all tags
tags, err := h.DB.GetTags(ctx)
tags, err := h.DB.GetTags(ctx, model.DBListTagsOptions{
WithBookmarkCount: true,
})
checkError(err)
w.Header().Set("Content-Type", "application/json")