feat: store/db module with sqlite

This commit is contained in:
boojack 2022-05-22 00:59:22 +08:00
parent c34cbb19bc
commit 8e01eb8702
19 changed files with 228 additions and 40 deletions

View file

@ -7,6 +7,7 @@ import (
"memos/common"
"memos/server"
"memos/store"
DB "memos/store/db"
)
const (
@ -25,14 +26,14 @@ type Main struct {
}
func (m *Main) Run() error {
db := store.NewDB(m.profile)
db := DB.NewDB(m.profile)
if err := db.Open(); err != nil {
return fmt.Errorf("cannot open db: %w", err)
}
s := server.NewServer(m.profile)
storeInstance := store.New(db)
storeInstance := store.New(db.Db, m.profile)
s.Store = storeInstance
if err := s.Run(); err != nil {

4
go.mod
View file

@ -12,7 +12,7 @@ require (
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.1 // indirect
golang.org/x/crypto v0.0.0-20210920023735-84f357641f63 // indirect
golang.org/x/crypto v0.0.0-20210920023735-84f357641f63
golang.org/x/net v0.0.0-20210917221730-978cfadd31cf // indirect
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect
golang.org/x/text v0.3.7 // indirect
@ -26,7 +26,7 @@ require (
)
require (
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1
github.com/labstack/echo-contrib v0.12.0
)

View file

@ -1,4 +1,4 @@
package store
package db
import (
"database/sql"
@ -76,6 +76,33 @@ func (db *DB) Open() (err error) {
}
func (db *DB) migrate() error {
table, err := findTable(db, "migration_history")
if err != nil {
return err
}
if table == nil {
createTable(db, `
CREATE TABLE migration_history (
version TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now'))
);
`)
}
migrationHistoryList, err := findMigrationHistoyList(db)
if err != nil {
return err
}
if len(migrationHistoryList) == 0 {
createMigrationHistoy(db, common.Version)
} else {
migrationHistory := migrationHistoryList[0]
if migrationHistory.Version != common.Version {
createMigrationHistoy(db, common.Version)
}
}
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/*.sql", "migration"))
if err != nil {
return err

View file

@ -0,0 +1,64 @@
package db
import (
"fmt"
"time"
)
type MigrationHistory struct {
CreatedTs int64
Version string
}
func findMigrationHistoyList(db *DB) ([]*MigrationHistory, error) {
rows, err := db.Db.Query(`
SELECT
version,
created_ts
FROM
migration_history
ORDER BY created_ts DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
migrationHistoryList := make([]*MigrationHistory, 0)
for rows.Next() {
var migrationHistory MigrationHistory
if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
migrationHistoryList = append(migrationHistoryList, &migrationHistory)
}
return migrationHistoryList, nil
}
func createMigrationHistoy(db *DB, version string) error {
result, err := db.Db.Exec(`
INSERT INTO migration_history (
version,
created_ts
)
VALUES (?, ?)
`,
version,
time.Now().Unix(),
)
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("failed to create migration history with %s", version)
}
return nil
}

65
store/db/table.go Normal file
View file

@ -0,0 +1,65 @@
package db
import (
"fmt"
"strings"
)
type Table struct {
Name string
SQL string
}
func findTable(db *DB, tableName string) (*Table, error) {
where, args := []string{"1 = 1"}, []interface{}{}
where, args = append(where, "type = ?"), append(args, "table")
where, args = append(where, "name = ?"), append(args, tableName)
rows, err := db.Db.Query(`
SELECT
tbl_name,
sql
FROM sqlite_schema
WHERE `+strings.Join(where, " AND "),
args...,
)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
tableList := make([]*Table, 0)
for rows.Next() {
var table Table
if err := rows.Scan(
&table.Name,
&table.SQL,
); err != nil {
return nil, FormatError(err)
}
tableList = append(tableList, &table)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
if len(tableList) == 0 {
return nil, nil
} else {
return tableList[0], nil
}
}
func createTable(db *DB, sql string) error {
result, err := db.Db.Exec(sql)
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("failed to create table with %s", sql)
}
return err
}

19
store/error.go Normal file
View file

@ -0,0 +1,19 @@
package store
import (
"database/sql"
"errors"
)
func FormatError(err error) error {
if err == nil {
return nil
}
switch err {
case sql.ErrNoRows:
return errors.New("data not found")
default:
return err
}
}

View file

@ -1,6 +1,7 @@
package store
import (
"database/sql"
"fmt"
"memos/api"
"memos/common"
@ -113,7 +114,7 @@ func (s *Store) DeleteMemo(delete *api.MemoDelete) error {
return nil
}
func createMemoRaw(db *DB, create *api.MemoCreate) (*memoRaw, error) {
func createMemoRaw(db *sql.DB, create *api.MemoCreate) (*memoRaw, error) {
set := []string{"creator_id", "content"}
placeholder := []string{"?", "?"}
args := []interface{}{create.CreatorID, create.Content}
@ -122,7 +123,7 @@ func createMemoRaw(db *DB, create *api.MemoCreate) (*memoRaw, error) {
set, placeholder, args = append(set, "created_ts"), append(placeholder, "?"), append(args, *v)
}
row, err := db.Db.Query(`
row, err := db.Query(`
INSERT INTO memo (
`+strings.Join(set, ", ")+`
)
@ -152,7 +153,7 @@ func createMemoRaw(db *DB, create *api.MemoCreate) (*memoRaw, error) {
return &memoRaw, nil
}
func patchMemoRaw(db *DB, patch *api.MemoPatch) (*memoRaw, error) {
func patchMemoRaw(db *sql.DB, patch *api.MemoPatch) (*memoRaw, error) {
set, args := []string{}, []interface{}{}
if v := patch.Content; v != nil {
@ -164,7 +165,7 @@ func patchMemoRaw(db *DB, patch *api.MemoPatch) (*memoRaw, error) {
args = append(args, patch.ID)
row, err := db.Db.Query(`
row, err := db.Query(`
UPDATE memo
SET `+strings.Join(set, ", ")+`
WHERE id = ?
@ -193,7 +194,7 @@ func patchMemoRaw(db *DB, patch *api.MemoPatch) (*memoRaw, error) {
return &memoRaw, nil
}
func findMemoRawList(db *DB, find *api.MemoFind) ([]*memoRaw, error) {
func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -209,7 +210,7 @@ func findMemoRawList(db *DB, find *api.MemoFind) ([]*memoRaw, error) {
where = append(where, "id in (SELECT memo_id FROM memo_organizer WHERE pinned = 1 AND user_id = memo.creator_id )")
}
rows, err := db.Db.Query(`
rows, err := db.Query(`
SELECT
id,
creator_id,
@ -250,8 +251,8 @@ func findMemoRawList(db *DB, find *api.MemoFind) ([]*memoRaw, error) {
return memoRawList, nil
}
func deleteMemo(db *DB, delete *api.MemoDelete) error {
result, err := db.Db.Exec(`DELETE FROM memo WHERE id = ?`, delete.ID)
func deleteMemo(db *sql.DB, delete *api.MemoDelete) error {
result, err := db.Exec(`DELETE FROM memo WHERE id = ?`, delete.ID)
if err != nil {
return FormatError(err)
}

View file

@ -1,6 +1,7 @@
package store
import (
"database/sql"
"fmt"
"memos/api"
"memos/common"
@ -47,8 +48,8 @@ func (s *Store) UpsertMemoOrganizer(upsert *api.MemoOrganizerUpsert) error {
return nil
}
func findMemoOrganizer(db *DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) {
row, err := db.Db.Query(`
func findMemoOrganizer(db *sql.DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) {
row, err := db.Query(`
SELECT
id,
memo_id,
@ -79,8 +80,8 @@ func findMemoOrganizer(db *DB, find *api.MemoOrganizerFind) (*memoOrganizerRaw,
return &memoOrganizerRaw, nil
}
func upsertMemoOrganizer(db *DB, upsert *api.MemoOrganizerUpsert) error {
row, err := db.Db.Query(`
func upsertMemoOrganizer(db *sql.DB, upsert *api.MemoOrganizerUpsert) error {
row, err := db.Query(`
INSERT INTO memo_organizer (
memo_id,
user_id,

View file

@ -1,6 +1,7 @@
package store
import (
"database/sql"
"fmt"
"memos/api"
"memos/common"
@ -90,8 +91,8 @@ func (s *Store) DeleteResource(delete *api.ResourceDelete) error {
return nil
}
func createResource(db *DB, create *api.ResourceCreate) (*resourceRaw, error) {
row, err := db.Db.Query(`
func createResource(db *sql.DB, create *api.ResourceCreate) (*resourceRaw, error) {
row, err := db.Query(`
INSERT INTO resource (
filename,
blob,
@ -130,7 +131,7 @@ func createResource(db *DB, create *api.ResourceCreate) (*resourceRaw, error) {
return &resourceRaw, nil
}
func findResourceList(db *DB, find *api.ResourceFind) ([]*resourceRaw, error) {
func findResourceList(db *sql.DB, find *api.ResourceFind) ([]*resourceRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -143,7 +144,7 @@ func findResourceList(db *DB, find *api.ResourceFind) ([]*resourceRaw, error) {
where, args = append(where, "filename = ?"), append(args, *v)
}
rows, err := db.Db.Query(`
rows, err := db.Query(`
SELECT
id,
filename,
@ -186,8 +187,8 @@ func findResourceList(db *DB, find *api.ResourceFind) ([]*resourceRaw, error) {
return resourceRawList, nil
}
func deleteResource(db *DB, delete *api.ResourceDelete) error {
result, err := db.Db.Exec(`DELETE FROM resource WHERE id = ?`, delete.ID)
func deleteResource(db *sql.DB, delete *api.ResourceDelete) error {
result, err := db.Exec(`DELETE FROM resource WHERE id = ?`, delete.ID)
if err != nil {
return FormatError(err)
}

View file

@ -1,6 +1,7 @@
package store
import (
"database/sql"
"fmt"
"memos/api"
"memos/common"
@ -97,8 +98,8 @@ func (s *Store) DeleteShortcut(delete *api.ShortcutDelete) error {
return nil
}
func createShortcut(db *DB, create *api.ShortcutCreate) (*shortcutRaw, error) {
row, err := db.Db.Query(`
func createShortcut(db *sql.DB, create *api.ShortcutCreate) (*shortcutRaw, error) {
row, err := db.Query(`
INSERT INTO shortcut (
title,
payload,
@ -133,7 +134,7 @@ func createShortcut(db *DB, create *api.ShortcutCreate) (*shortcutRaw, error) {
return &shortcutRaw, nil
}
func patchShortcut(db *DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
func patchShortcut(db *sql.DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
set, args := []string{}, []interface{}{}
if v := patch.Title; v != nil {
@ -148,7 +149,7 @@ func patchShortcut(db *DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
args = append(args, patch.ID)
row, err := db.Db.Query(`
row, err := db.Query(`
UPDATE shortcut
SET `+strings.Join(set, ", ")+`
WHERE id = ?
@ -178,7 +179,7 @@ func patchShortcut(db *DB, patch *api.ShortcutPatch) (*shortcutRaw, error) {
return &shortcutRaw, nil
}
func findShortcutList(db *DB, find *api.ShortcutFind) ([]*shortcutRaw, error) {
func findShortcutList(db *sql.DB, find *api.ShortcutFind) ([]*shortcutRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -191,7 +192,7 @@ func findShortcutList(db *DB, find *api.ShortcutFind) ([]*shortcutRaw, error) {
where, args = append(where, "title = ?"), append(args, *v)
}
rows, err := db.Db.Query(`
rows, err := db.Query(`
SELECT
id,
title,
@ -234,8 +235,8 @@ func findShortcutList(db *DB, find *api.ShortcutFind) ([]*shortcutRaw, error) {
return shortcutRawList, nil
}
func deleteShortcut(db *DB, delete *api.ShortcutDelete) error {
result, err := db.Db.Exec(`DELETE FROM shortcut WHERE id = ?`, delete.ID)
func deleteShortcut(db *sql.DB, delete *api.ShortcutDelete) error {
result, err := db.Exec(`DELETE FROM shortcut WHERE id = ?`, delete.ID)
if err != nil {
return FormatError(err)
}

View file

@ -1,13 +1,20 @@
package store
import (
"database/sql"
"memos/common"
)
// Store provides database access to all raw objects
type Store struct {
db *DB
db *sql.DB
profile *common.Profile
}
// New creates a new instance of Store
func New(db *DB) *Store {
func New(db *sql.DB, profile *common.Profile) *Store {
return &Store{
db: db,
db: db,
profile: profile,
}
}

View file

@ -1,6 +1,7 @@
package store
import (
"database/sql"
"fmt"
"memos/api"
"memos/common"
@ -94,8 +95,8 @@ func (s *Store) FindUser(find *api.UserFind) (*api.User, error) {
return user, nil
}
func createUser(db *DB, create *api.UserCreate) (*userRaw, error) {
row, err := db.Db.Query(`
func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) {
row, err := db.Query(`
INSERT INTO user (
email,
role,
@ -135,7 +136,7 @@ func createUser(db *DB, create *api.UserCreate) (*userRaw, error) {
return &userRaw, nil
}
func patchUser(db *DB, patch *api.UserPatch) (*userRaw, error) {
func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) {
set, args := []string{}, []interface{}{}
if v := patch.RowStatus; v != nil {
@ -156,7 +157,7 @@ func patchUser(db *DB, patch *api.UserPatch) (*userRaw, error) {
args = append(args, patch.ID)
row, err := db.Db.Query(`
row, err := db.Query(`
UPDATE user
SET `+strings.Join(set, ", ")+`
WHERE id = ?
@ -188,7 +189,7 @@ func patchUser(db *DB, patch *api.UserPatch) (*userRaw, error) {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)}
}
func findUserList(db *DB, find *api.UserFind) ([]*userRaw, error) {
func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if v := find.ID; v != nil {
@ -207,7 +208,7 @@ func findUserList(db *DB, find *api.UserFind) ([]*userRaw, error) {
where, args = append(where, "open_id = ?"), append(args, *v)
}
rows, err := db.Db.Query(`
rows, err := db.Query(`
SELECT
id,
email,