From a797280e3fdf62a3b27d3bfc7186b10d49e98e37 Mon Sep 17 00:00:00 2001 From: boojack Date: Sun, 1 Jan 2023 23:26:21 +0800 Subject: [PATCH] chore: update middleware skipper (#887) * chore: update middleware skipper * chore: update --- .github/workflows/backend-tests-default.yml | 20 ------ .github/workflows/backend-tests.yml | 2 - .github/workflows/frontend-tests-default.yml | 25 -------- .github/workflows/frontend-tests.yml | 2 - api/auth.go | 4 +- server/acl.go | 67 ++++++-------------- server/auth.go | 10 +-- server/common.go | 45 +++++++++++-- server/server.go | 41 ++---------- server/shortcut.go | 2 +- server/tag.go | 16 ++--- web/src/helpers/api.ts | 2 +- web/src/pages/Auth.tsx | 4 +- 13 files changed, 83 insertions(+), 157 deletions(-) delete mode 100644 .github/workflows/backend-tests-default.yml delete mode 100644 .github/workflows/frontend-tests-default.yml diff --git a/.github/workflows/backend-tests-default.yml b/.github/workflows/backend-tests-default.yml deleted file mode 100644 index 02bc4f3b..00000000 --- a/.github/workflows/backend-tests-default.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: Default Backend Test - -on: - pull_request: - branches: - - main - - "release/*.*.*" - paths: - - "web/**" - -jobs: - go-static-checks: - runs-on: ubuntu-latest - steps: - - run: 'echo "Not required"' - - go-tests: - runs-on: ubuntu-latest - steps: - - run: 'echo "Not required"' diff --git a/.github/workflows/backend-tests.yml b/.github/workflows/backend-tests.yml index 28beedc5..ce87f186 100644 --- a/.github/workflows/backend-tests.yml +++ b/.github/workflows/backend-tests.yml @@ -5,8 +5,6 @@ on: branches: - main - "release/*.*.*" - paths-ignore: - - "web/**" jobs: go-static-checks: diff --git a/.github/workflows/frontend-tests-default.yml b/.github/workflows/frontend-tests-default.yml deleted file mode 100644 index caee7572..00000000 --- a/.github/workflows/frontend-tests-default.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Default Frontend Test - -on: - pull_request: - branches: - - main - - "release/*.*.*" - paths-ignore: - - "web/**" - -jobs: - eslint-checks: - runs-on: ubuntu-latest - steps: - - run: 'echo "Not required"' - - jest-tests: - runs-on: ubuntu-latest - steps: - - run: 'echo "Not required"' - - frontend-build: - runs-on: ubuntu-latest - steps: - - run: 'echo "Not required"' diff --git a/.github/workflows/frontend-tests.yml b/.github/workflows/frontend-tests.yml index e0ac5f95..76748f10 100644 --- a/.github/workflows/frontend-tests.yml +++ b/.github/workflows/frontend-tests.yml @@ -5,8 +5,6 @@ on: branches: - main - "release/*.*.*" - paths: - - "web/**" jobs: eslint-checks: diff --git a/api/auth.go b/api/auth.go index c44a5587..5ed4ebf1 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,11 +1,11 @@ package api -type Signin struct { +type SignIn struct { Username string `json:"username"` Password string `json:"password"` } -type Signup struct { +type SignUp struct { Username string `json:"username"` Password string `json:"password"` Role Role `json:"role"` diff --git a/server/acl.go b/server/acl.go index 87e5bad7..da93b9f1 100644 --- a/server/acl.go +++ b/server/acl.go @@ -15,6 +15,7 @@ import ( var ( userIDContextKey = "user-id" + sessionName = "memos_session" ) func getUserIDContextKey() string { @@ -22,7 +23,7 @@ func getUserIDContextKey() string { } func setUserSession(ctx echo.Context, user *api.User) error { - sess, _ := session.Get("memos_session", ctx) + sess, _ := session.Get(sessionName, ctx) sess.Options = &sessions.Options{ Path: "/", MaxAge: 3600 * 24 * 30, @@ -38,7 +39,7 @@ func setUserSession(ctx echo.Context, user *api.User) error { } func removeUserSession(ctx echo.Context) error { - sess, _ := session.Get("memos_session", ctx) + sess, _ := session.Get(sessionName, ctx) sess.Options = &sessions.Options{ Path: "/", MaxAge: 0, @@ -57,61 +58,33 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { ctx := c.Request().Context() path := c.Path() - // Skip auth. - if common.HasPrefixes(path, "/api/auth") { + if s.DefaultAuthSkipper(c) { return next(c) } - { - // If there is openId in query string and related user is found, then skip auth. - openID := c.QueryParam("openId") - if openID != "" { - userFind := &api.UserFind{ - OpenID: &openID, - } - user, err := s.Store.FindUser(ctx, userFind) - if err != nil && common.ErrorCode(err) != common.NotFound { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err) - } - if user != nil { - // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) - return next(c) + sess, _ := session.Get(sessionName, c) + userIDValue := sess.Values[userIDContextKey] + if userIDValue != nil { + userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) + userFind := &api.UserFind{ + ID: &userID, + } + user, err := s.Store.FindUser(ctx, userFind) + if err != nil && common.ErrorCode(err) != common.NotFound { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) + } + if user != nil { + if user.RowStatus == api.Archived { + return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username)) } + c.Set(getUserIDContextKey(), userID) } } - { - sess, _ := session.Get("memos_session", c) - userIDValue := sess.Values[userIDContextKey] - if userIDValue != nil { - userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) - userFind := &api.UserFind{ - ID: &userID, - } - user, err := s.Store.FindUser(ctx, userFind) - if err != nil && common.ErrorCode(err) != common.NotFound { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) - } - if user != nil { - if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username)) - } - c.Set(getUserIDContextKey(), userID) - } - } - } - - if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo/all", "/api/memo/:memoId", "/api/memo/amount") && c.Request().Method == http.MethodGet { + if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo") && c.Request().Method == http.MethodGet { return next(c) } - if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet { - if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - return next(c) - } - } - userID := c.Get(getUserIDContextKey()) if userID == nil { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") diff --git a/server/auth.go b/server/auth.go index f8124512..e92cb893 100644 --- a/server/auth.go +++ b/server/auth.go @@ -16,7 +16,7 @@ import ( func (s *Server) registerAuthRoutes(g *echo.Group) { g.POST("/auth/signin", func(c echo.Context) error { ctx := c.Request().Context() - signin := &api.Signin{} + signin := &api.SignIn{} if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) } @@ -56,7 +56,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { g.POST("/auth/signup", func(c echo.Context) error { ctx := c.Request().Context() - signup := &api.Signup{} + signup := &api.SignUp{} if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) } @@ -130,14 +130,14 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return nil }) - g.POST("/auth/logout", func(c echo.Context) error { + g.POST("/auth/signout", func(c echo.Context) error { ctx := c.Request().Context() err := removeUserSession(c) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err) + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set sign out session").SetInternal(err) } s.Collector.Collect(ctx, &metric.Metric{ - Name: "user logout", + Name: "user signout", }) return c.JSON(http.StatusOK, true) diff --git a/server/common.go b/server/common.go index 033bc5d6..9c3b1297 100644 --- a/server/common.go +++ b/server/common.go @@ -1,11 +1,46 @@ package server -func composeResponse(data interface{}) interface{} { - type R struct { - Data interface{} `json:"data"` - } +import ( + "github.com/labstack/echo/v4" + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" +) - return R{ +type response struct { + Data interface{} `json:"data"` +} + +func composeResponse(data interface{}) response { + return response{ Data: data, } } + +func (server *Server) DefaultAuthSkipper(c echo.Context) bool { + ctx := c.Request().Context() + path := c.Path() + + // Skip auth. + if common.HasPrefixes(path, "/api/auth") { + return true + } + + // If there is openId in query string and related user is found, then skip auth. + openID := c.QueryParam("openId") + if openID != "" { + userFind := &api.UserFind{ + OpenID: &openID, + } + user, err := server.Store.FindUser(ctx, userFind) + if err != nil && common.ErrorCode(err) != common.NotFound { + return false + } + if user != nil { + // Stores userID into context. + c.Set(getUserIDContextKey(), user.ID) + return true + } + } + + return false +} diff --git a/server/server.go b/server/server.go index 3be61e46..322063c5 100644 --- a/server/server.go +++ b/server/server.go @@ -4,8 +4,6 @@ import ( "fmt" "time" - "github.com/usememos/memos/api" - "github.com/usememos/memos/common" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -43,8 +41,12 @@ func NewServer(profile *profile.Profile) *Server { `"status":${status},"error":"${error}"}` + "\n", })) + e.Use(middleware.Gzip()) + e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ - Skipper: s.OpenAPISkipper, + Skipper: func(c echo.Context) bool { + return s.DefaultAuthSkipper(c) + }, TokenLookup: "cookie:_csrf", })) @@ -92,35 +94,6 @@ func NewServer(profile *profile.Profile) *Server { return s } -func (server *Server) Run() error { - return server.e.Start(fmt.Sprintf(":%d", server.Profile.Port)) -} - -func (server *Server) OpenAPISkipper(c echo.Context) bool { - ctx := c.Request().Context() - path := c.Path() - - // Skip auth. - if common.HasPrefixes(path, "/api/auth") { - return true - } - - // If there is openId in query string and related user is found, then skip auth. - openID := c.QueryParam("openId") - if openID != "" { - userFind := &api.UserFind{ - OpenID: &openID, - } - user, err := server.Store.FindUser(ctx, userFind) - if err != nil && common.ErrorCode(err) != common.NotFound { - return false - } - if user != nil { - // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) - return true - } - } - - return false +func (s *Server) Run() error { + return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port)) } diff --git a/server/shortcut.go b/server/shortcut.go index 23b875a1..4f78032a 100644 --- a/server/shortcut.go +++ b/server/shortcut.go @@ -91,10 +91,10 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut") } + shortcutFind := &api.ShortcutFind{ CreatorID: &userID, } - list, err := s.Store.FindShortcutList(ctx, shortcutFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) diff --git a/server/tag.go b/server/tag.go index fc47e7ad..d1ef45d9 100644 --- a/server/tag.go +++ b/server/tag.go @@ -6,7 +6,6 @@ import ( "net/http" "regexp" "sort" - "strconv" "github.com/usememos/memos/api" "github.com/usememos/memos/common" @@ -49,19 +48,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) { g.GET("/tag", func(c echo.Context) error { ctx := c.Request().Context() - tagFind := &api.TagFind{} - if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - tagFind.CreatorID = userID + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") } - if tagFind.CreatorID == 0 { - currentUserID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") - } - tagFind.CreatorID = currentUserID + tagFind := &api.TagFind{ + CreatorID: userID, } - tagList, err := s.Store.FindTagList(ctx, tagFind) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err) diff --git a/web/src/helpers/api.ts b/web/src/helpers/api.ts index d6529ee7..94097094 100644 --- a/web/src/helpers/api.ts +++ b/web/src/helpers/api.ts @@ -34,7 +34,7 @@ export function signup(username: string, password: string, role: UserRole) { } export function signout() { - return axios.post("/api/auth/logout"); + return axios.post("/api/auth/signout"); } export function createUser(userCreate: UserCreate) { diff --git a/web/src/pages/Auth.tsx b/web/src/pages/Auth.tsx index 100d326e..4d315b4f 100644 --- a/web/src/pages/Auth.tsx +++ b/web/src/pages/Auth.tsx @@ -51,7 +51,7 @@ const Auth = () => { globalStore.setAppearance(appearance); }; - const handleSigninBtnsClick = async () => { + const handleSignInBtnClick = async () => { if (actionBtnLoadingState.isLoading) { return; } @@ -153,7 +153,7 @@ const Auth = () => { / )} -