diff --git a/api/v1/auth.go b/api/v1/auth.go index 11e881d0..dfa7339c 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -32,7 +32,7 @@ type SignUp struct { Password string `json:"password"` } -func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { +func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { g.POST("/auth/signin", func(c echo.Context) error { ctx := c.Request().Context() signin := &SignIn{} @@ -58,7 +58,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again") } - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { @@ -144,7 +144,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier)) } - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { @@ -208,7 +208,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + if err := auth.GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createAuthSignUpActivity(c, user); err != nil { diff --git a/api/v1/idp.go b/api/v1/idp.go index 03aa5e21..89b6ecdf 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -45,14 +45,14 @@ type IdentityProvider struct { Config *IdentityProviderConfig `json:"config"` } -type IdentityProviderCreate struct { +type CreateIdentityProviderRequest struct { Name string `json:"name"` Type IdentityProviderType `json:"type"` IdentifierFilter string `json:"identifierFilter"` Config *IdentityProviderConfig `json:"config"` } -type IdentityProviderPatch struct { +type UpdateIdentityProviderRequest struct { ID int Type IdentityProviderType `json:"type"` Name *string `json:"name"` @@ -78,7 +78,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } - identityProviderCreate := &IdentityProviderCreate{} + identityProviderCreate := &CreateIdentityProviderRequest{} if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err) } @@ -117,7 +117,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) } - identityProviderPatch := &IdentityProviderPatch{ + identityProviderPatch := &UpdateIdentityProviderRequest{ ID: identityProviderID, } if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil { diff --git a/api/v1/jwt.go b/api/v1/jwt.go new file mode 100644 index 00000000..6551e039 --- /dev/null +++ b/api/v1/jwt.go @@ -0,0 +1,239 @@ +package v1 + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" + "github.com/usememos/memos/common" + "github.com/usememos/memos/server/auth" + "github.com/usememos/memos/store" +) + +const ( + // Context section + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + userIDContextKey = "user-id" +) + +func getUserIDContextKey() string { + return userIDContextKey +} + +// Claims creates a struct that will be encoded to a JWT. +// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. +type Claims struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +func extractTokenFromHeader(c echo.Context) (string, error) { + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +func findAccessToken(c echo.Context) string { + accessToken := "" + cookie, _ := c.Cookie(auth.AccessTokenCookieName) + if cookie != nil { + accessToken = cookie.Value + } + if accessToken == "" { + accessToken, _ = extractTokenFromHeader(c) + } + + return accessToken +} + +func audienceContains(audience jwt.ClaimStrings, token string) bool { + for _, v := range audience { + if v == token { + return true + } + } + return false +} + +// JWTMiddleware validates the access token. +// If the access token is about to expire or has expired and the request has a valid refresh token, it +// will try to generate new access token and refresh token. +func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Request().URL.Path + method := c.Request().Method + + if server.defaultAuthSkipper(c) { + return next(c) + } + + // Skip validation for server status endpoints. + if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet { + return next(c) + } + + token := findAccessToken(c) + if token == "" { + // Allow the user to access the public endpoints. + if common.HasPrefixes(path, "/o") { + return next(c) + } + // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. + if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { + return next(c) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") + } + + claims := &Claims{} + accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + + if !accessToken.Valid { + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") + } + + if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) + } + generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration + if err != nil { + var ve *jwt.ValidationError + if errors.As(err, &ve) { + // If expiration error is the only error, we will clear the err + // and generate new access token and refresh token + if ve.Errors == jwt.ValidationErrorExpired { + generateToken = true + } + } else { + return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) + } + } + + // We either have a valid access token or we will attempt to generate new access token and refresh token + ctx := c.Request().Context() + userID, err := strconv.Atoi(claims.Subject) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") + } + + // Even if there is no error, we still need to make sure the user still exists. + user, err := server.Store.GetUser(ctx, &store.FindUserMessage{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) + } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) + } + + if generateToken { + generateTokenFunc := func() error { + rc, err := c.Cookie(auth.RefreshTokenCookieName) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") + } + + // Parses token and checks if it's valid. + refreshTokenClaims := &Claims{} + refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) + } + + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) + }) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + + if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { + return echo.NewHTTPError(http.StatusUnauthorized, + fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + refreshTokenClaims.Audience, + auth.RefreshTokenAudienceName, + )) + } + + // If we have a valid refresh token, we will generate new access token and refresh token + if refreshToken != nil && refreshToken.Valid { + if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + } + + return nil + } + + // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token + // In such case, we won't return the error. + if err := generateTokenFunc(); err != nil && !accessToken.Valid { + return err + } + } + + // Stores userID into context. + c.Set(getUserIDContextKey(), userID) + return next(c) + } +} + +func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { + ctx := c.Request().Context() + path := c.Path() + + // Skip auth. + if common.HasPrefixes(path, "/api/v1/auth") { + return true + } + + // If there is openId in query string and related user is found, then skip auth. + openID := c.QueryParam("openId") + if openID != "" { + user, err := s.Store.GetUser(ctx, &store.FindUserMessage{ + OpenID: &openID, + }) + if err != nil && common.ErrorCode(err) != common.NotFound { + return false + } + if user != nil { + // Stores userID into context. + c.Set(getUserIDContextKey(), user.ID) + return true + } + } + + return false +} diff --git a/api/v1/v1.go b/api/v1/v1.go index 9287c327..1875031d 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -6,17 +6,6 @@ import ( "github.com/usememos/memos/store" ) -const ( - // Context section - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - userIDContextKey = "user-id" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - type APIV1Service struct { Secret string Profile *profile.Profile @@ -31,8 +20,12 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store } } -func (s *APIV1Service) Register(apiV1Group *echo.Group) { +func (s *APIV1Service) Register(rootGroup *echo.Group) { + apiV1Group := rootGroup.Group("/api/v1") + apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return JWTMiddleware(s, next, s.Secret) + }) s.registerTestRoutes(apiV1Group) - s.registerAuthRoutes(apiV1Group, s.Secret) + s.registerAuthRoutes(apiV1Group) s.registerIdentityProviderRoutes(apiV1Group) } diff --git a/server/server.go b/server/server.go index 15facd4b..159629d2 100644 --- a/server/server.go +++ b/server/server.go @@ -108,12 +108,8 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store s.registerOpenAIRoutes(apiGroup) s.registerMemoRelationRoutes(apiGroup) - apiV1Group := e.Group("/api/v1") - apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return JWTMiddleware(s, next, s.Secret) - }) apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store) - apiV1Service.Register(apiV1Group) + apiV1Service.Register(rootGroup) return s, nil }