diff --git a/api/auth.go b/api/auth.go index 70c9bbf53..7d69d15f7 100644 --- a/api/auth.go +++ b/api/auth.go @@ -30,13 +30,10 @@ func handleUserSignUp(w http.ResponseWriter, r *http.Request) { return } - userIdCookie := &http.Cookie{ - Name: "user_id", - Value: user.Id, - Path: "/", - MaxAge: 3600 * 24 * 30, - } - http.SetCookie(w, userIdCookie) + session, _ := SessionStore.Get(r, "session") + + session.Values["user_id"] = user.Id + session.Save(r, w) json.NewEncoder(w).Encode(Response{ Succeed: true, @@ -66,13 +63,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) { return } - userIdCookie := &http.Cookie{ - Name: "user_id", - Value: user.Id, - Path: "/", - MaxAge: 3600 * 24 * 30, - } - http.SetCookie(w, userIdCookie) + session, _ := SessionStore.Get(r, "session") + + session.Values["user_id"] = user.Id + session.Save(r, w) json.NewEncoder(w).Encode(Response{ Succeed: true, @@ -82,13 +76,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) { } func handleUserSignOut(w http.ResponseWriter, r *http.Request) { - userIdCookie := &http.Cookie{ - Name: "user_id", - Value: "", - Path: "/", - MaxAge: 0, - } - http.SetCookie(w, userIdCookie) + session, _ := SessionStore.Get(r, "session") + + session.Values["user_id"] = "" + session.Save(r, w) json.NewEncoder(w).Encode(Response{ Succeed: true, diff --git a/api/memo.go b/api/memo.go index d37534ecf..09be5d497 100644 --- a/api/memo.go +++ b/api/memo.go @@ -10,7 +10,7 @@ import ( ) func handleGetMyMemos(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) urlParams := r.URL.Query() deleted := urlParams.Get("deleted") onlyDeletedFlag := deleted == "true" @@ -34,7 +34,7 @@ type CreateMemo struct { } func handleCreateMemo(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) createMemo := CreateMemo{} err := json.NewDecoder(r.Body).Decode(&createMemo) @@ -105,6 +105,8 @@ func handleDeleteMemo(w http.ResponseWriter, r *http.Request) { func RegisterMemoRoutes(r *mux.Router) { memoRouter := r.PathPrefix("/api/memo").Subrouter() + memoRouter.Use(AuthCheckerMiddleWare) + memoRouter.HandleFunc("/all", handleGetMyMemos).Methods("GET") memoRouter.HandleFunc("/", handleCreateMemo).Methods("PUT") memoRouter.HandleFunc("/{id}", handleUpdateMemo).Methods("PATCH") diff --git a/api/middlewares.go b/api/middlewares.go index ca37d4432..270b61190 100644 --- a/api/middlewares.go +++ b/api/middlewares.go @@ -7,9 +7,9 @@ import ( func AuthCheckerMiddleWare(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userId, err := GetUserIdInCookie(r) + session, _ := SessionStore.Get(r, "session") - if err != nil || userId == "" { + if userId, ok := session.Values["user_id"].(string); !ok || userId == "" { e.ErrorHandler(w, "NOT_AUTH", "Need authorize") return } diff --git a/api/query.go b/api/query.go index 6a20b36ca..546dd6efb 100644 --- a/api/query.go +++ b/api/query.go @@ -10,7 +10,7 @@ import ( ) func handleGetMyQueries(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) queries, err := store.GetQueriesByUserId(userId) @@ -32,7 +32,7 @@ type QueryPut struct { } func handleCreateQuery(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) queryPut := QueryPut{} err := json.NewDecoder(r.Body).Decode(&queryPut) @@ -103,6 +103,8 @@ func handleDeleteQuery(w http.ResponseWriter, r *http.Request) { func RegisterQueryRoutes(r *mux.Router) { queryRouter := r.PathPrefix("/api/query").Subrouter() + queryRouter.Use(AuthCheckerMiddleWare) + queryRouter.HandleFunc("/all", handleGetMyQueries).Methods("GET") queryRouter.HandleFunc("/", handleCreateQuery).Methods("PUT") queryRouter.HandleFunc("/{id}", handleUpdateQuery).Methods("PATCH") diff --git a/api/session.go b/api/session.go new file mode 100644 index 000000000..c4d805153 --- /dev/null +++ b/api/session.go @@ -0,0 +1,9 @@ +package api + +import ( + "memos/common" + + "github.com/gorilla/sessions" +) + +var SessionStore = sessions.NewCookieStore([]byte(common.GenUUID())) diff --git a/api/user.go b/api/user.go index f627ccd8a..61da75a9e 100644 --- a/api/user.go +++ b/api/user.go @@ -10,7 +10,7 @@ import ( ) func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) user, err := store.GetUserById(userId) @@ -27,7 +27,7 @@ func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) { } func handleUpdateMyUserInfo(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) userPatch := store.UserPatch{} err := json.NewDecoder(r.Body).Decode(&userPatch) @@ -83,7 +83,7 @@ type ValidPassword struct { } func handleValidPassword(w http.ResponseWriter, r *http.Request) { - userId, _ := GetUserIdInCookie(r) + userId, _ := GetUserIdInSession(r) validPassword := ValidPassword{} err := json.NewDecoder(r.Body).Decode(&validPassword) diff --git a/api/utils.go b/api/utils.go index 149a07357..a28c58666 100644 --- a/api/utils.go +++ b/api/utils.go @@ -10,12 +10,14 @@ type Response struct { Data interface{} `json:"data"` } -func GetUserIdInCookie(r *http.Request) (string, error) { - userIdCookie, err := r.Cookie("user_id") +func GetUserIdInSession(r *http.Request) (string, error) { + session, _ := SessionStore.Get(r, "session") - if err != nil { - return "", err + userId, ok := session.Values["user_id"].(string) + + if !ok { + return "", http.ErrNoCookie } - return userIdCookie.Value, err + return userId, nil } diff --git a/go.mod b/go.mod index 973cdcf32..03e4be15f 100644 --- a/go.mod +++ b/go.mod @@ -7,3 +7,8 @@ require github.com/gorilla/mux v1.8.0 require github.com/mattn/go-sqlite3 v1.14.9 require github.com/google/uuid v1.3.0 + +require ( + github.com/gorilla/securecookie v1.1.1 // indirect + github.com/gorilla/sessions v1.2.1 +) diff --git a/go.sum b/go.sum index fda9a0098..b87722be2 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,9 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= diff --git a/resources/initial_db.sql b/resources/initial_db.sql index a34fe7811..3bc24828a 100644 --- a/resources/initial_db.sql +++ b/resources/initial_db.sql @@ -1,26 +1,41 @@ /* - * Re-create tables and insert initial data(todo) + * Re-create tables and insert initial data */ +DROP TABLE IF EXISTS `users`; CREATE TABLE `users` ( `id` TEXT NOT NULL PRIMARY KEY, `username` TEXT NOT NULL, `password` TEXT NOT NULL, - `github_name` TEXT NULL DEFAULT '', - `wx_open_id` TEXT NULL DEFAULT '', + `github_name` TEXT DEFAULT '', + `wx_open_id` TEXT DEFAULT '', `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP ); +INSERT INTO `users` + (`id`, `username`, `password`) +VALUES + ('0', 'admin', '123456'), + ('1', 'guest', '123456'); + +DROP TABLE IF EXISTS `memos`; CREATE TABLE `memos` ( `id` TEXT NOT NULL PRIMARY KEY, `content` TEXT NOT NULL, `user_id` TEXT NOT NULL, `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - `deleted_at` TEXT, + `deleted_at` TEXT DEFAULT '', FOREIGN KEY(`user_id`) REFERENCES `users`(`id`) ); +INSERT INTO `memos` + (`id`, `content`, `user_id`, ) +VALUES + ('0', '👋 Welcome to memos', '0'), + ('1', '👋 Welcome to memos', '1'); + +DROP TABLE IF EXISTS `queries`; CREATE TABLE `queries` ( `id` TEXT NOT NULL PRIMARY KEY, `user_id` TEXT NOT NULL, @@ -28,6 +43,6 @@ CREATE TABLE `queries` ( `querystring` TEXT NOT NULL, `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - `pinned_at` TEXT NULL, + `pinned_at` TEXT DEFAULT '', FOREIGN KEY(`user_id`) REFERENCES `users`(`id`) ); diff --git a/resources/memos.db b/resources/memos.db index ef548529e..e420b8d2c 100644 Binary files a/resources/memos.db and b/resources/memos.db differ diff --git a/store/db.go b/store/db.go new file mode 100644 index 000000000..4f76bf7a3 --- /dev/null +++ b/store/db.go @@ -0,0 +1,50 @@ +package store + +import ( + "database/sql" + "os" + + _ "github.com/mattn/go-sqlite3" +) + +/* + * Use a global variable to save the db connection: Quick and easy to setup. + * Reference: https://techinscribed.com/different-approaches-to-pass-database-connection-into-controllers-in-golang/ + */ +var DB *sql.DB + +func InitDBConn() { + dbFilePath := "/data/memos.db" + + if _, err := os.Stat(dbFilePath); err != nil { + dbFilePath = "./resources/memos.db" + resetDataInDefaultDatabase() + println("use the default database") + } else { + println("use the custom database") + } + + db, err := sql.Open("sqlite3", dbFilePath) + + if err != nil { + println("connect failed") + } else { + DB = db + println("connect to sqlite succeed") + } +} + +func FormatDBError(err error) error { + if err == nil { + return nil + } + + switch err.Error() { + default: + return err + } +} + +func resetDataInDefaultDatabase() { + // do nth +} diff --git a/store/memo.go b/store/memo.go index 2034f7e4b..8bcb4e7b4 100644 --- a/store/memo.go +++ b/store/memo.go @@ -65,34 +65,36 @@ func DeleteMemo(memoId string) (error, error) { } func GetMemoById(id string) (Memo, error) { - query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE id=?` + query := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE id=?` memo := Memo{} - err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) + err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) return memo, err } -func GetMemosByUserId(userId string, deleted bool) ([]Memo, error) { - query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE user_id=?` +func GetMemosByUserId(userId string, onlyDeleted bool) ([]Memo, error) { + sqlQuery := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE user_id=?` - if deleted { - query = query + ` AND deleted_at!=""` + if onlyDeleted { + sqlQuery = sqlQuery + ` AND deleted_at!=""` } else { - query = query + ` AND deleted_at=""` + sqlQuery = sqlQuery + ` AND deleted_at=""` } - rows, _ := DB.Query(query, userId) + rows, _ := DB.Query(sqlQuery, userId) defer rows.Close() memos := []Memo{} for rows.Next() { memo := Memo{} - rows.Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) + rows.Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) memos = append(memos, memo) } - err := rows.Err() + if err := rows.Err(); err != nil { + return nil, err + } - return memos, err + return memos, nil } diff --git a/store/query.go b/store/query.go index 1200365df..97e0163d4 100644 --- a/store/query.go +++ b/store/query.go @@ -27,8 +27,8 @@ func CreateNewQuery(title string, querystring string, userId string) (Query, err UpdatedAt: nowDateTimeStr, } - query := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)` - _, err := DB.Exec(query, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt) + sqlQuery := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)` + _, err := DB.Exec(sqlQuery, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt) return newQuery, err } @@ -72,14 +72,14 @@ func DeleteQuery(queryId string) (error, error) { } func GetQueryById(queryId string) (Query, error) { - sqlQuery := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE id=?` + sqlQuery := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE id=?` query := Query{} - err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) + err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) return query, err } func GetQueriesByUserId(userId string) ([]Query, error) { - query := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE user_id=?` + query := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE user_id=?` rows, _ := DB.Query(query, userId) defer rows.Close() @@ -88,12 +88,14 @@ func GetQueriesByUserId(userId string) ([]Query, error) { for rows.Next() { query := Query{} - rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) + rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) queries = append(queries, query) } - err := rows.Err() + if err := rows.Err(); err != nil { + return nil, err + } - return queries, err + return queries, nil } diff --git a/store/sqlite.go b/store/sqlite.go deleted file mode 100644 index 19fbc25fb..000000000 --- a/store/sqlite.go +++ /dev/null @@ -1,31 +0,0 @@ -package store - -import ( - "database/sql" - "fmt" - - _ "github.com/mattn/go-sqlite3" -) - -var DB *sql.DB - -func InitDBConn() { - db, err := sql.Open("sqlite3", "./resources/memos.db") - if err != nil { - fmt.Println("connect failed") - } else { - DB = db - fmt.Println("connect to sqlite succeed") - } -} - -func FormatDBError(err error) error { - if err == nil { - return nil - } - - switch err.Error() { - default: - return err - } -}