From ab2c86640b904f7903d12b759a14d2f1388d211d Mon Sep 17 00:00:00 2001 From: Steven Date: Mon, 9 Oct 2023 23:10:41 +0800 Subject: [PATCH] chore: move rate limiter to apiv1 --- api/v1/v1.go | 19 +++++++++++++++++++ server/server.go | 17 ----------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/api/v1/v1.go b/api/v1/v1.go index 76c87dfd..78d6b1ac 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -1,7 +1,11 @@ package v1 import ( + "net/http" + "time" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" "github.com/usememos/memos/api/resource" "github.com/usememos/memos/plugin/telegram" @@ -45,6 +49,21 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { // Register API v1 routes. apiV1Group := rootGroup.Group("/api/v1") + apiV1Group.Use(middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{ + Store: middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 30, Burst: 100, ExpiresIn: 3 * time.Minute}, + ), + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return context.JSON(http.StatusForbidden, nil) + }, + DenyHandler: func(context echo.Context, identifier string, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + })) apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return JWTMiddleware(s, next, s.Secret) }) diff --git a/server/server.go b/server/server.go index 09bff45b..d8a6f35e 100644 --- a/server/server.go +++ b/server/server.go @@ -77,23 +77,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store Timeout: 30 * time.Second, })) - e.Use(middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{ - Skipper: grpcRequestSkipper, - Store: middleware.NewRateLimiterMemoryStoreWithConfig( - middleware.RateLimiterMemoryStoreConfig{Rate: 30, Burst: 100, ExpiresIn: 3 * time.Minute}, - ), - IdentifierExtractor: func(ctx echo.Context) (string, error) { - id := ctx.RealIP() - return id, nil - }, - ErrorHandler: func(context echo.Context, err error) error { - return context.JSON(http.StatusForbidden, nil) - }, - DenyHandler: func(context echo.Context, identifier string, err error) error { - return context.JSON(http.StatusTooManyRequests, nil) - }, - })) - serverID, err := s.getSystemServerID(ctx) if err != nil { return nil, errors.Wrap(err, "failed to retrieve system server ID")