From d0ddac296f5599c470d65ccc2de5346c1d268c88 Mon Sep 17 00:00:00 2001 From: boojack Date: Thu, 6 Apr 2023 07:42:39 +0800 Subject: [PATCH] chore: update store error handler (#1479) --- server/jwt.go | 19 +++++++++---------- server/rss.go | 25 +++++++------------------ 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/server/jwt.go b/server/jwt.go index f8d0f120..5cb6e252 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -1,7 +1,6 @@ package server import ( - "errors" "fmt" "net/http" "strconv" @@ -10,7 +9,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" - pkgerrors "github.com/pkg/errors" + "github.com/pkg/errors" "github.com/usememos/memos/api" "github.com/usememos/memos/common" "github.com/usememos/memos/server/auth" @@ -38,7 +37,7 @@ func getUserIDContextKey() string { func GenerateTokensAndSetCookies(c echo.Context, user *api.User, mode string, secret string) error { accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, mode, secret) if err != nil { - return pkgerrors.Wrap(err, "failed to generate access token") + return errors.Wrap(err, "failed to generate access token") } cookieExp := time.Now().Add(auth.CookieExpDuration) @@ -47,7 +46,7 @@ func GenerateTokensAndSetCookies(c echo.Context, user *api.User, mode string, se // We generate here a new refresh token and saving it to the cookie. refreshToken, err := auth.GenerateRefreshToken(user.Username, user.ID, mode, secret) if err != nil { - return pkgerrors.Wrap(err, "failed to generate refresh token") + return errors.Wrap(err, "failed to generate refresh token") } setTokenCookie(c, auth.RefreshTokenCookieName, refreshToken, cookieExp) @@ -116,7 +115,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha } // Skip validation for server status endpoints. - if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/idp", "/api/user/:id") && method == http.MethodGet { + if common.HasPrefixes(path, "/api/ping", "/api/idp", "/api/user/:id") && method == http.MethodGet { return next(c) } @@ -127,7 +126,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha return next(c) } // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. - if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet { + if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { return next(c) } return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") @@ -136,14 +135,14 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha claims := &Claims{} accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, pkgerrors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) } if kid, ok := t.Header["kid"].(string); ok { if kid == "v1" { return []byte(secret), nil } } - return nil, pkgerrors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) }) if !audienceContains(claims.Audience, fmt.Sprintf(auth.AccessTokenAudienceFmt, mode)) { @@ -202,7 +201,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha refreshTokenClaims := &Claims{} refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, pkgerrors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) + return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) } if kid, ok := t.Header["kid"].(string); ok { @@ -210,7 +209,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha return []byte(secret), nil } } - return nil, pkgerrors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) + return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) }) if err != nil { if err == jwt.ErrSignatureInvalid { diff --git a/server/rss.go b/server/rss.go index 872f27c5..9da775ad 100644 --- a/server/rss.go +++ b/server/rss.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/feeds" "github.com/labstack/echo/v4" "github.com/usememos/memos/api" + "github.com/usememos/memos/common" ) func (s *Server) registerRSSRoutes(g *echo.Group) { @@ -92,13 +93,10 @@ func generateRSSFromMemoList(memoList []*api.Memo, baseURL string, profile *api. Created: time.Now(), } - var itemCountLimit = min(len(memoList), MaxRSSItemCount) - + var itemCountLimit = common.Min(len(memoList), MaxRSSItemCount) feed.Items = make([]*feeds.Item, itemCountLimit) - for i := 0; i < itemCountLimit; i++ { memo := memoList[i] - feed.Items[i] = &feeds.Item{ Title: getRSSItemTitle(memo.Content), Link: &feeds.Link{Href: baseURL + "/m/" + strconv.Itoa(memo.ID)}, @@ -126,31 +124,22 @@ func getSystemCustomizedProfile(ctx context.Context, s *Server) (*api.Customized systemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ Name: api.SystemSettingCustomizedProfileName, }) - if err != nil { - return customizedProfile, err + if err != nil && common.ErrorCode(err) != common.NotFound { + return nil, err } - - err = json.Unmarshal([]byte(systemSetting.Value), customizedProfile) - if err != nil { - return customizedProfile, err + if err := json.Unmarshal([]byte(systemSetting.Value), customizedProfile); err != nil { + return nil, err } return customizedProfile, nil } -func min(a, b int) int { - if a < b { - return a - } - return b -} - func getRSSItemTitle(content string) string { var title string if isTitleDefined(content) { title = strings.Split(content, "\n")[0][2:] } else { title = strings.Split(content, "\n")[0] - var titleLengthLimit = min(len(title), MaxRSSItemTitleLength) + var titleLengthLimit = common.Min(len(title), MaxRSSItemTitleLength) if titleLengthLimit < len(title) { title = title[:titleLengthLimit] + "..." }