shiori/internal/database/pg.go

619 lines
15 KiB
Go
Raw Normal View History

2019-09-25 00:59:25 +08:00
package database
import (
"database/sql"
"fmt"
2022-02-13 23:38:27 +08:00
"log"
2019-09-25 00:59:25 +08:00
"strings"
"time"
"github.com/go-shiori/shiori/internal/model"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
2019-09-25 00:59:25 +08:00
"github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
)
// PGDatabase is implementation of Database interface
// for connecting to PostgreSQL database.
//
// It embeds sqlx.DB by value, so the sqlx helpers (Get, Select, Exec,
// Beginx, Preparex, Rebind, ...) are called directly on *PGDatabase.
type PGDatabase struct {
	sqlx.DB
}
// OpenPGDatabase creates and opens connection to a PostgreSQL Database.
func OpenPGDatabase(connString string) (pgDB *PGDatabase, err error) {
// Open database and start transaction
db := sqlx.MustConnect("postgres", connString)
db.SetMaxOpenConns(100)
db.SetConnMaxLifetime(time.Second)
2019-09-25 00:59:25 +08:00
pgDB = &PGDatabase{*db}
return pgDB, err
}
// Migrate runs migrations for this database engine
func (db *PGDatabase) Migrate() error {
sourceDriver, err := iofs.New(migrations, "migrations/postgres")
checkError(err)
dbDriver, err := postgres.WithInstance(db.DB.DB, &postgres.Config{})
checkError(err)
migration, err := migrate.NewWithInstance(
"iofs",
sourceDriver,
"postgres",
dbDriver,
)
checkError(err)
return migration.Up()
}
2019-09-25 00:59:25 +08:00
// SaveBookmarks saves new or updated bookmarks to database.
// Returns the saved ID and error message if any happened.
func (db *PGDatabase) SaveBookmarks(bookmarks ...model.Bookmark) (result []model.Bookmark, err error) {
// Prepare transaction
tx, err := db.Beginx()
if err != nil {
return []model.Bookmark{}, err
}
// Make sure to rollback if panic ever happened
defer func() {
if r := recover(); r != nil {
panicErr, _ := r.(error)
2022-02-13 23:38:27 +08:00
if err := tx.Rollback(); err != nil {
log.Printf("error during rollback: %s", err)
}
2019-09-25 00:59:25 +08:00
result = []model.Bookmark{}
err = panicErr
}
}()
// Prepare statement
stmtInsertBook, err := tx.Preparex(`INSERT INTO bookmark
(url, title, excerpt, author, public, content, html, modified)
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT(url) DO UPDATE SET
url = $1,
title = $2,
excerpt = $3,
author = $4,
public = $5,
content = $6,
html = $7,
modified = $8`)
checkError(err)
stmtGetTag, err := tx.Preparex(`SELECT id FROM tag WHERE name = $1`)
checkError(err)
stmtInsertTag, err := tx.Preparex(`INSERT INTO tag (name) VALUES ($1) RETURNING id`)
checkError(err)
stmtInsertBookTag, err := tx.Preparex(`INSERT INTO bookmark_tag
(tag_id, bookmark_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`)
checkError(err)
stmtDeleteBookTag, err := tx.Preparex(`DELETE FROM bookmark_tag
WHERE bookmark_id = $1 AND tag_id = $2`)
checkError(err)
// Prepare modified time
modifiedTime := time.Now().UTC().Format("2006-01-02 15:04:05")
// Execute statements
result = []model.Bookmark{}
for _, book := range bookmarks {
// Check ID, URL and title
if book.ID == 0 {
panic(fmt.Errorf("ID must not be empty"))
}
if book.URL == "" {
panic(fmt.Errorf("URL must not be empty"))
}
if book.Title == "" {
panic(fmt.Errorf("title must not be empty"))
}
// Set modified time
book.Modified = modifiedTime
// Save bookmark
stmtInsertBook.MustExec(
book.URL, book.Title, book.Excerpt, book.Author,
book.Public, book.Content, book.HTML, book.Modified)
// Save book tags
newTags := []model.Tag{}
for _, tag := range book.Tags {
// If it's deleted tag, delete and continue
if tag.Deleted {
stmtDeleteBookTag.MustExec(book.ID, tag.ID)
continue
}
// Normalize tag name
tagName := strings.ToLower(tag.Name)
tagName = strings.Join(strings.Fields(tagName), " ")
// If tag doesn't have any ID, fetch it from database
if tag.ID == 0 {
err = stmtGetTag.Get(&tag.ID, tagName)
checkError(err)
// If tag doesn't exist in database, save it
if tag.ID == 0 {
var tagID64 int64
err = stmtInsertTag.Get(&tagID64, tagName)
checkError(err)
tag.ID = int(tagID64)
}
2022-02-13 23:38:27 +08:00
if _, err := stmtInsertBookTag.Exec(tag.ID, book.ID); err != nil {
log.Printf("error during insert: %s", err)
}
2019-09-25 00:59:25 +08:00
}
newTags = append(newTags, tag)
}
book.Tags = newTags
result = append(result, book)
}
// Commit transaction
err = tx.Commit()
checkError(err)
return result, err
}
// GetBookmarks fetch list of bookmarks based on submitted options.
func (db *PGDatabase) GetBookmarks(opts GetBookmarksOptions) ([]model.Bookmark, error) {
// Create initial query
columns := []string{
`id`,
`url`,
`title`,
`excerpt`,
`author`,
`public`,
`modified`,
`content <> '' has_content`}
if opts.WithContent {
columns = append(columns, `content`, `html`)
}
query := `SELECT ` + strings.Join(columns, ",") + `
FROM bookmark WHERE TRUE`
// Add where clause
arg := map[string]interface{}{}
// Add where clause for IDs
if len(opts.IDs) > 0 {
query += ` AND id IN (:ids)`
arg["ids"] = opts.IDs
}
// Add where clause for search keyword
if opts.Keyword != "" {
query += ` AND (
2022-02-13 23:38:27 +08:00
url LIKE :lkw OR
2019-09-25 00:59:25 +08:00
MATCH(title, excerpt, content) AGAINST (:kw IN BOOLEAN MODE)
)`
2019-10-03 22:47:03 +08:00
arg["lkw"] = "%" + opts.Keyword + "%"
2019-09-25 00:59:25 +08:00
arg["kw"] = opts.Keyword
}
// Add where clause for tags.
// First we check for * in excluded and included tags,
// which means all tags will be excluded and included, respectively.
excludeAllTags := false
for _, excludedTag := range opts.ExcludedTags {
if excludedTag == "*" {
excludeAllTags = true
opts.ExcludedTags = []string{}
break
}
}
includeAllTags := false
for _, includedTag := range opts.Tags {
if includedTag == "*" {
includeAllTags = true
opts.Tags = []string{}
break
}
}
// If all tags excluded, we will only show bookmark without tags.
// In other hand, if all tags included, we will only show bookmark with tags.
if excludeAllTags {
query += ` AND id NOT IN (SELECT DISTINCT bookmark_id FROM bookmark_tag)`
} else if includeAllTags {
query += ` AND id IN (SELECT DISTINCT bookmark_id FROM bookmark_tag)`
}
// Now we only need to find the normal tags
if len(opts.Tags) > 0 {
query += ` AND id IN (
SELECT bt.bookmark_id
FROM bookmark_tag bt
LEFT JOIN tag t ON bt.tag_id = t.id
WHERE t.name IN(:tags)
GROUP BY bt.bookmark_id
HAVING COUNT(bt.bookmark_id) = :ltags)`
arg["tags"] = opts.Tags
arg["ltags"] = len(opts.Tags)
}
if len(opts.ExcludedTags) > 0 {
query += ` AND id NOT IN (
SELECT DISTINCT bt.bookmark_id
FROM bookmark_tag bt
LEFT JOIN tag t ON bt.tag_id = t.id
WHERE t.name IN(:extags))`
arg["extags"] = opts.ExcludedTags
}
// Add order clause
switch opts.OrderMethod {
case ByLastAdded:
query += ` ORDER BY id DESC`
case ByLastModified:
query += ` ORDER BY modified DESC`
default:
query += ` ORDER BY id`
}
if opts.Limit > 0 && opts.Offset >= 0 {
query += ` LIMIT :limit OFFSET :offset`
arg["limit"] = opts.Limit
arg["offset"] = opts.Offset
}
// Expand query, because some of the args might be an array
2022-02-13 23:38:27 +08:00
var err error
query, args, _ := sqlx.Named(query, arg)
2019-09-25 00:59:25 +08:00
query, args, err = sqlx.In(query, args...)
if err != nil {
return nil, fmt.Errorf("failed to expand query: %v", err)
}
query = db.Rebind(query)
// Fetch bookmarks
bookmarks := []model.Bookmark{}
err = db.Select(&bookmarks, query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to fetch data: %v", err)
}
// Fetch tags for each bookmarks
2022-02-13 23:38:27 +08:00
stmtGetTags, err := db.Preparex(`SELECT t.id, t.name
FROM bookmark_tag bt
2019-09-25 00:59:25 +08:00
LEFT JOIN tag t ON bt.tag_id = t.id
2022-02-13 23:38:27 +08:00
WHERE bt.bookmark_id = $1
2019-09-25 00:59:25 +08:00
ORDER BY t.name`)
if err != nil {
return nil, fmt.Errorf("failed to prepare tag query: %v", err)
}
defer stmtGetTags.Close()
for i, book := range bookmarks {
book.Tags = []model.Tag{}
err = stmtGetTags.Select(&book.Tags, book.ID)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to fetch tags: %v", err)
}
bookmarks[i] = book
}
return bookmarks, nil
}
// GetBookmarksCount fetch count of bookmarks based on submitted options.
func (db *PGDatabase) GetBookmarksCount(opts GetBookmarksOptions) (int, error) {
// Create initial query
query := `SELECT COUNT(id) FROM bookmark WHERE TRUE`
arg := map[string]interface{}{}
// Add where clause for IDs
if len(opts.IDs) > 0 {
query += ` AND id IN (:ids)`
arg["ids"] = opts.IDs
}
// Add where clause for search keyword
if opts.Keyword != "" {
query += ` AND (
2022-02-13 23:38:27 +08:00
url LIKE :lurl OR
2019-09-25 00:59:25 +08:00
MATCH(title, excerpt, content) AGAINST (:kw IN BOOLEAN MODE)
)`
2019-10-03 22:47:03 +08:00
arg["lurl"] = "%" + opts.Keyword + "%"
2019-09-25 00:59:25 +08:00
arg["kw"] = opts.Keyword
}
// Add where clause for tags.
// First we check for * in excluded and included tags,
// which means all tags will be excluded and included, respectively.
excludeAllTags := false
for _, excludedTag := range opts.ExcludedTags {
if excludedTag == "*" {
excludeAllTags = true
opts.ExcludedTags = []string{}
break
}
}
includeAllTags := false
for _, includedTag := range opts.Tags {
if includedTag == "*" {
includeAllTags = true
opts.Tags = []string{}
break
}
}
// If all tags excluded, we will only show bookmark without tags.
// In other hand, if all tags included, we will only show bookmark with tags.
if excludeAllTags {
query += ` AND id NOT IN (SELECT DISTINCT bookmark_id FROM bookmark_tag)`
} else if includeAllTags {
query += ` AND id IN (SELECT DISTINCT bookmark_id FROM bookmark_tag)`
}
// Now we only need to find the normal tags
if len(opts.Tags) > 0 {
query += ` AND id IN (
SELECT bt.bookmark_id
FROM bookmark_tag bt
LEFT JOIN tag t ON bt.tag_id = t.id
WHERE t.name IN(:tags)
GROUP BY bt.bookmark_id
HAVING COUNT(bt.bookmark_id) = :ltags)`
arg["tags"] = opts.Tags
arg["ltags"] = len(opts.Tags)
}
if len(opts.ExcludedTags) > 0 {
query += ` AND id NOT IN (
SELECT DISTINCT bt.bookmark_id
FROM bookmark_tag bt
LEFT JOIN tag t ON bt.tag_id = t.id
WHERE t.name IN(:etags))`
arg["etags"] = opts.ExcludedTags
}
// Expand query, because some of the args might be an array
2022-02-13 23:38:27 +08:00
var err error
query, args, _ := sqlx.Named(query, arg)
2019-09-25 00:59:25 +08:00
query, args, err = sqlx.In(query, args...)
if err != nil {
return 0, fmt.Errorf("failed to expand query: %v", err)
}
query = db.Rebind(query)
// Fetch count
var nBookmarks int
err = db.Get(&nBookmarks, query, args...)
if err != nil && err != sql.ErrNoRows {
return 0, fmt.Errorf("failed to fetch count: %v", err)
}
return nBookmarks, nil
}
// DeleteBookmarks removes all record with matching ids from database.
func (db *PGDatabase) DeleteBookmarks(ids ...int) (err error) {
// Begin transaction
tx, err := db.Beginx()
if err != nil {
return err
}
// Make sure to rollback if panic ever happened
defer func() {
if r := recover(); r != nil {
panicErr, _ := r.(error)
2022-02-13 23:38:27 +08:00
if err := tx.Rollback(); err != nil {
log.Printf("error during rollback: %s", err)
}
2019-09-25 00:59:25 +08:00
err = panicErr
}
}()
// Prepare queries
delBookmark := `DELETE FROM bookmark`
delBookmarkTag := `DELETE FROM bookmark_tag`
// Delete bookmark(s)
if len(ids) == 0 {
tx.MustExec(delBookmarkTag)
tx.MustExec(delBookmark)
} else {
delBookmark += ` WHERE id = $1`
delBookmarkTag += ` WHERE bookmark_id = $1`
stmtDelBookmark, _ := tx.Preparex(delBookmark)
stmtDelBookmarkTag, _ := tx.Preparex(delBookmarkTag)
for _, id := range ids {
stmtDelBookmarkTag.MustExec(id)
stmtDelBookmark.MustExec(id)
}
}
// Commit transaction
err = tx.Commit()
checkError(err)
return err
}
// GetBookmark fetchs bookmark based on its ID or URL.
// Returns the bookmark and boolean whether it's exist or not.
func (db *PGDatabase) GetBookmark(id int, url string) (model.Bookmark, bool) {
args := []interface{}{id}
query := `SELECT
2022-02-13 23:38:27 +08:00
id, url, title, excerpt, author, public,
2019-10-03 22:47:03 +08:00
content, html, modified, content <> '' has_content
2019-09-25 00:59:25 +08:00
FROM bookmark WHERE id = $1`
if url != "" {
query += ` OR url = $2`
args = append(args, url)
}
book := model.Bookmark{}
2022-02-13 23:38:27 +08:00
if err := db.Get(&book, query, args...); err != nil {
log.Printf("error during db.get: %s", err)
}
2019-09-25 00:59:25 +08:00
return book, book.ID != 0
}
// SaveAccount stores account in the database, replacing the password and
// owner flag of an existing account with the same username.
// The password is never stored in plain text; a bcrypt hash (cost 10) is
// written instead. Returns error if any happened.
func (db *PGDatabase) SaveAccount(account model.Account) (err error) {
	const passwordCost = 10

	hash, err := bcrypt.GenerateFromPassword([]byte(account.Password), passwordCost)
	if err != nil {
		return err
	}

	const upsertAccount = `INSERT INTO account
		(username, password, owner) VALUES ($1, $2, $3)
		ON CONFLICT(username) DO UPDATE SET
		password = $2,
		owner = $3`

	_, err = db.Exec(upsertAccount, account.Username, hash, account.Owner)
	return err
}
// GetAccounts lists accounts ordered by username, without their password.
// opts.Keyword, when set, keeps usernames containing it (SQL LIKE), and
// opts.Owner keeps only owner accounts.
func (db *PGDatabase) GetAccounts(opts GetAccountsOptions) ([]model.Account, error) {
	sqlText := `SELECT id, username, owner FROM account WHERE TRUE`
	params := make([]interface{}, 0, 1)

	if kw := opts.Keyword; kw != "" {
		sqlText += " AND username LIKE $1"
		params = append(params, "%"+kw+"%")
	}

	if opts.Owner {
		sqlText += " AND owner = TRUE"
	}

	sqlText += ` ORDER BY username`

	// Non-nil so an empty result is an empty list, not nil
	accounts := make([]model.Account, 0)
	if err := db.Select(&accounts, sqlText, params...); err != nil && err != sql.ErrNoRows {
		return nil, fmt.Errorf("failed to fetch accounts: %v", err)
	}

	return accounts, nil
}
// GetAccount fetch account with matching username.
// Returns the account and boolean whether it's exist or not.
func (db *PGDatabase) GetAccount(username string) (model.Account, bool) {
account := model.Account{}
2022-02-13 23:38:27 +08:00
if err := db.Get(&account, `SELECT
2019-09-25 00:59:25 +08:00
id, username, password, owner FROM account WHERE username = $1`,
2022-02-13 23:38:27 +08:00
username,
); err != nil {
log.Printf("error during db.get: %s", err)
}
2019-10-03 22:47:03 +08:00
2019-09-25 00:59:25 +08:00
return account, account.ID != 0
}
// DeleteAccounts removes all record with matching usernames.
func (db *PGDatabase) DeleteAccounts(usernames ...string) (err error) {
// Begin transaction
tx, err := db.Beginx()
if err != nil {
return err
}
// Make sure to rollback if panic ever happened
defer func() {
if r := recover(); r != nil {
panicErr, _ := r.(error)
2022-02-13 23:38:27 +08:00
if err := tx.Rollback(); err != nil {
log.Printf("error during rollback: %s", err)
}
2019-09-25 00:59:25 +08:00
err = panicErr
}
}()
// Delete account
stmtDelete, _ := tx.Preparex(`DELETE FROM account WHERE username = $1`)
for _, username := range usernames {
stmtDelete.MustExec(username)
}
// Commit transaction
err = tx.Commit()
checkError(err)
return err
}
// GetTags fetch list of tags and their frequency.
func (db *PGDatabase) GetTags() ([]model.Tag, error) {
tags := []model.Tag{}
2022-02-13 23:38:27 +08:00
query := `SELECT bt.tag_id id, t.name, COUNT(bt.tag_id) n_bookmarks
FROM bookmark_tag bt
2019-09-25 00:59:25 +08:00
LEFT JOIN tag t ON bt.tag_id = t.id
GROUP BY bt.tag_id, t.name ORDER BY t.name`
err := db.Select(&tags, query)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to fetch tags: %v", err)
}
return tags, nil
}
// RenameTag sets the name of the tag identified by id to newName.
// The number of affected rows is not checked, so renaming an unknown id is
// not reported as an error.
func (db *PGDatabase) RenameTag(id int, newName string) error {
	const renameStmt = `UPDATE tag SET name = $1 WHERE id = $2`

	if _, execErr := db.Exec(renameStmt, newName, id); execErr != nil {
		return execErr
	}
	return nil
}
// CreateNewID creates new ID for specified table
func (db *PGDatabase) CreateNewID(table string) (int, error) {
2019-10-03 22:47:03 +08:00
var tableID int
2019-09-25 00:59:25 +08:00
query := fmt.Sprintf(`SELECT last_value from %s_id_seq;`, table)
err := db.Get(&tableID, query)
if err != nil && err != sql.ErrNoRows {
return -1, err
}
return tableID, nil
}