mirror of
https://github.com/go-shiori/shiori.git
synced 2025-10-02 09:46:59 +08:00
* refactor: base http server stdlib * refactor: swagger and frontend routes * fix: use global middlewares * refactor: removed gin from testutils * fix: object references in legacy webserver * refactor: legacy, swagger and system handlers * fix: added verbs to handlers * fix: server handlers ordering * refactor: bookmarks handlers * refactor: system api routes * tests: bookmark handlers * refactor: migrated api auth routes * chore: remove unused middlewares * docs: add swagger docs to refactored system api * chore: remove old auth routes * refactor: account apis * chore: removed old handlers * fix: api v1 handlers missing middlewares * refactor: migrated tag list route * refactor: bookmark routes * refactor: remove gin * chore: make styles * test: fixed tests * test: generate binary file without text * fix: global middleware missing from system api handler * fix: incorrect api handler * chore: avoid logging screenshot contents * tests: bookmarks domain * tests: shortcuts * test: missing tests * tests: server tests * test: remove test using syscall to avoid windows errors * chore: added middlewares
66 lines
1.4 KiB
Go
66 lines
1.4 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/go-shiori/shiori/internal/model"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// Sentinel errors shared by the database backends.
var (
	// ErrNotFound reports that the requested record does not exist in the
	// database.
	ErrNotFound = errors.New("not found")

	// ErrAlreadyExists reports that a record to be created is already
	// present in the database.
	ErrAlreadyExists = errors.New("already exists")
)
|
|
|
|
// Connect connects to database based on submitted database URL.
|
|
func Connect(ctx context.Context, dbURL string) (model.DB, error) {
|
|
dbU, err := url.Parse(dbURL)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to parse database URL")
|
|
}
|
|
|
|
switch dbU.Scheme {
|
|
case "mysql":
|
|
urlNoSchema := strings.Split(dbURL, "://")[1]
|
|
return OpenMySQLDatabase(ctx, urlNoSchema)
|
|
case "postgres":
|
|
return OpenPGDatabase(ctx, dbURL)
|
|
case "sqlite":
|
|
return OpenSQLiteDatabase(ctx, dbU.Path[1:])
|
|
}
|
|
|
|
return nil, fmt.Errorf("unsupported database scheme: %s", dbU.Scheme)
|
|
}
|
|
|
|
type dbbase struct {
|
|
*sqlx.DB
|
|
}
|
|
|
|
func (db *dbbase) withTx(ctx context.Context, fn func(tx *sqlx.Tx) error) error {
|
|
tx, err := db.BeginTxx(ctx, nil)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
defer func() {
|
|
if err := tx.Commit(); err != nil {
|
|
log.Printf("error during commit: %s", err)
|
|
}
|
|
}()
|
|
|
|
err = fn(tx)
|
|
if err != nil {
|
|
if err := tx.Rollback(); err != nil {
|
|
log.Printf("error during rollback: %s", err)
|
|
}
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return err
|
|
}
|