mirror of
https://github.com/usememos/memos.git
synced 2024-09-20 14:35:54 +08:00
fix: get&set session
This commit is contained in:
parent
d661134b03
commit
a8f0c9a7b1
|
@ -1,11 +1,11 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
type Login struct {
|
type Login struct {
|
||||||
Name string
|
Name string `jsonapi:"attr,name"`
|
||||||
Password string
|
Password string `jsonapi:"attr,password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Signup struct {
|
type Signup struct {
|
||||||
Name string
|
Name string `jsonapi:"attr,name"`
|
||||||
Password string
|
Password string `jsonapi:"attr,password"`
|
||||||
}
|
}
|
||||||
|
|
12
api/user.go
12
api/user.go
|
@ -5,25 +5,25 @@ type User struct {
|
||||||
CreatedTs int64 `jsonapi:"attr,createdTs"`
|
CreatedTs int64 `jsonapi:"attr,createdTs"`
|
||||||
UpdatedTs int64 `jsonapi:"attr,updatedTs"`
|
UpdatedTs int64 `jsonapi:"attr,updatedTs"`
|
||||||
|
|
||||||
|
OpenId string `jsonapi:"attr,openId"`
|
||||||
Name string `jsonapi:"attr,name"`
|
Name string `jsonapi:"attr,name"`
|
||||||
Password string
|
Password string
|
||||||
OpenId string `jsonapi:"attr,openId"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserCreate struct {
|
type UserCreate struct {
|
||||||
|
OpenId string `jsonapi:"attr,openId"`
|
||||||
Name string `jsonapi:"attr,name"`
|
Name string `jsonapi:"attr,name"`
|
||||||
Password string `jsonapi:"attr,password"`
|
Password string `jsonapi:"attr,password"`
|
||||||
OpenId string `jsonapi:"attr,openId"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserPatch struct {
|
type UserPatch struct {
|
||||||
Id int
|
Id int
|
||||||
|
|
||||||
Name *string `jsonapi:"attr,name"`
|
OpenId *string
|
||||||
Password *string `jsonapi:"attr,password"`
|
|
||||||
OpenId *string
|
|
||||||
|
|
||||||
ResetOpenId *bool `jsonapi:"attr,resetOpenId"`
|
Name *string `jsonapi:"attr,name"`
|
||||||
|
Password *string `jsonapi:"attr,password"`
|
||||||
|
ResetOpenId *bool `jsonapi:"attr,resetOpenId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserFind struct {
|
type UserFind struct {
|
||||||
|
|
|
@ -34,26 +34,31 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err)
|
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = setUserSession(c, user)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set login session").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
||||||
if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil {
|
if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setUserSession(c, user)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
g.POST("/auth/logout", func(c echo.Context) error {
|
g.POST("/auth/logout", func(c echo.Context) error {
|
||||||
removeUserSession(c)
|
err := removeUserSession(c)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
|
||||||
c.Response().WriteHeader(http.StatusOK)
|
c.Response().WriteHeader(http.StatusOK)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
g.POST("/auth/signup", func(c echo.Context) error {
|
g.POST("/auth/signup", func(c echo.Context) error {
|
||||||
signup := &api.Signup{}
|
signup := &api.Signup{}
|
||||||
if err := jsonapi.UnmarshalPayload(c.Request().Body, signup); err != nil {
|
if err := jsonapi.UnmarshalPayload(c.Request().Body, signup); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted login request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userFind := &api.UserFind{
|
userFind := &api.UserFind{
|
||||||
|
@ -77,12 +82,16 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = setUserSession(c, user)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signup session").SetInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
|
||||||
if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil {
|
if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setUserSession(c, user)
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,33 +21,49 @@ func getUserIdContextKey() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purpose of this cookie is to store the user's id.
|
// Purpose of this cookie is to store the user's id.
|
||||||
func setUserSession(c echo.Context, user *api.User) {
|
func setUserSession(c echo.Context, user *api.User) error {
|
||||||
sess, _ := session.Get("session", c)
|
sess, err := session.Get("session", c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get session")
|
||||||
|
}
|
||||||
sess.Options = &sessions.Options{
|
sess.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: 1000 * 3600 * 24 * 30,
|
MaxAge: 1000 * 3600 * 24 * 30,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
}
|
}
|
||||||
sess.Values[userIdContextKey] = strconv.Itoa(user.Id)
|
sess.Values[userIdContextKey] = user.Id
|
||||||
sess.Save(c.Request(), c.Response())
|
err = sess.Save(c.Request(), c.Response())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeUserSession(c echo.Context) {
|
func removeUserSession(c echo.Context) error {
|
||||||
sess, _ := session.Get("session", c)
|
sess, err := session.Get("session", c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get session")
|
||||||
|
}
|
||||||
sess.Options = &sessions.Options{
|
sess.Options = &sessions.Options{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: 0,
|
MaxAge: 0,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
}
|
}
|
||||||
sess.Values[userIdContextKey] = nil
|
sess.Values[userIdContextKey] = nil
|
||||||
sess.Save(c.Request(), c.Response())
|
err = sess.Save(c.Request(), c.Response())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use session instead of jwt in the initial version
|
// Use session instead of jwt in the initial version
|
||||||
func JWTMiddleware(us api.UserService, next echo.HandlerFunc) echo.HandlerFunc {
|
func JWTMiddleware(us api.UserService, next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
// Skips auth, test
|
// Skips auth
|
||||||
if common.HasPrefixes(c.Path(), "/api/auth", "/api/test") {
|
if common.HasPrefixes(c.Path(), "/api/auth") {
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +71,13 @@ func JWTMiddleware(us api.UserService, next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing session")
|
||||||
}
|
}
|
||||||
userId, err := strconv.Atoi(fmt.Sprintf("%v", sess.Values[userIdContextKey]))
|
|
||||||
|
userIdValue := sess.Values[userIdContextKey]
|
||||||
|
if userIdValue == nil {
|
||||||
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userId in session")
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, err := strconv.Atoi(fmt.Sprintf("%v", userIdValue))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Failed to malformatted user id in the session.")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Failed to malformatted user id in the session.")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"memos/api"
|
"memos/api"
|
||||||
|
"memos/common"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/labstack/echo-contrib/session"
|
"github.com/labstack/echo-contrib/session"
|
||||||
|
@ -33,7 +34,7 @@ func NewServer() *Server {
|
||||||
HTML5: true,
|
HTML5: true,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
e.Use(session.Middleware(sessions.NewCookieStore([]byte("secret"))))
|
e.Use(session.Middleware(sessions.NewCookieStore([]byte(common.GenUUID()))))
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
e: e,
|
e: e,
|
||||||
|
|
|
@ -124,7 +124,7 @@ func patchUser(db *DB, patch *api.UserPatch) (*api.User, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) {
|
func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) {
|
||||||
where, args := []string{}, []interface{}{}
|
where, args := []string{"1 = 1"}, []interface{}{}
|
||||||
|
|
||||||
if v := find.Id; v != nil {
|
if v := find.Id; v != nil {
|
||||||
where, args = append(where, "id = ?"), append(args, *v)
|
where, args = append(where, "id = ?"), append(args, *v)
|
||||||
|
@ -142,7 +142,7 @@ func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) {
|
||||||
name,
|
name,
|
||||||
password,
|
password,
|
||||||
open_id,
|
open_id,
|
||||||
created_ts
|
created_ts,
|
||||||
updated_ts
|
updated_ts
|
||||||
FROM user
|
FROM user
|
||||||
WHERE `+strings.Join(where, " AND "),
|
WHERE `+strings.Join(where, " AND "),
|
||||||
|
@ -164,6 +164,7 @@ func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) {
|
||||||
&user.CreatedTs,
|
&user.CreatedTs,
|
||||||
&user.UpdatedTs,
|
&user.UpdatedTs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
return nil, FormatError(err)
|
return nil, FormatError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue