diff --git a/api/auth/auth.go b/api/auth/auth.go index 8abb5d3b..907c005a 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -1,13 +1,13 @@ package auth import ( + "fmt" "time" + + "github.com/golang-jwt/jwt/v4" ) const ( - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - UserIDContextKey = "user-id" // issuer is the issuer of the jwt token. Issuer = "memos" // Signing key section. For now, this is only used for signing, not for verifying since we only @@ -23,3 +23,42 @@ const ( // AccessTokenCookieName is the cookie name of access token. AccessTokenCookieName = "memos.access-token" ) + +type ClaimsMessage struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +// GenerateAccessToken generates an access token. +// username is the email of the user. +func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret string) (string, error) { + return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// generateToken generates a jwt token. +func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) { + registeredClaims := jwt.RegisteredClaims{ + Issuer: Issuer, + Audience: jwt.ClaimStrings{audience}, + IssuedAt: jwt.NewNumericDate(time.Now()), + Subject: fmt.Sprint(userID), + } + if expirationTime.After(time.Now()) { + registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime) + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{ + Name: username, + RegisteredClaims: registeredClaims, + }) + token.Header["kid"] = KeyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/api/v1/auth.go b/api/v1/auth.go index 981a6ba5..ea8dee00 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -5,9 +5,11 @@ import ( "fmt" "net/http" "regexp" + "time" "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp/oauth2" @@ -94,12 +96,15 @@ func (s *APIV1Service) SignIn(c echo.Context) error { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again") } - if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) + accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) } + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) userMessage := convertUserFromStore(user) return c.JSON(http.StatusOK, userMessage) } @@ -213,12 +218,15 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier)) } - if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) + accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } if err := s.createAuthSignInActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) } + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) userMessage := convertUserFromStore(user) return c.JSON(http.StatusOK, userMessage) } @@ -304,13 +312,15 @@ func (s *APIV1Service) SignUp(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } - if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) + accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } if err := s.createAuthSignUpActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) } - + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) userMessage := convertUserFromStore(user) return c.JSON(http.StatusOK, userMessage) } @@ -358,3 +368,22 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User } return err } + +// RemoveTokensAndCookies removes the jwt token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} diff --git a/api/v1/idp.go b/api/v1/idp.go index 4780f396..433fdb0a 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) @@ -88,7 +87,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err) } - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) isHostUser := false if ok { user, err := s.Store.GetUser(ctx, &store.FindUser{ @@ -129,7 +128,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error { // @Router /api/v1/idp [POST] func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -177,7 +176,7 @@ func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error { // @Router /api/v1/idp/{idpId} [GET] func (s *APIV1Service) GetIdentityProvider(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -224,7 +223,7 @@ func (s *APIV1Service) GetIdentityProvider(c echo.Context) error { // @Router /api/v1/idp/{idpId} [DELETE] func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -266,7 +265,7 @@ func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error { // @Router /api/v1/idp/{idpId} [PATCH] func (s *APIV1Service) UpdateIdentityProvider(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 5192c6c3..57052992 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "strings" - "time" "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" @@ -14,75 +13,11 @@ import ( "github.com/usememos/memos/store" ) -type claimsMessage struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(username string, userID int32, secret string) (string, error) { - expirationTime := time.Now().Add(auth.AccessTokenDuration) - return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. -func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { - accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate access token") - } - - cookieExp := time.Now().Add(auth.CookieExpDuration) - setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) - return nil -} - -// RemoveTokensAndCookies removes the jwt token from the cookies. -func RemoveTokensAndCookies(c echo.Context) { - cookieExp := time.Now().Add(-1 * time.Hour) - setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) -} - -// setTokenCookie sets the token to the cookie. -func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { - cookie := new(http.Cookie) - cookie.Name = name - cookie.Value = token - cookie.Expires = expiration - cookie.Path = "/" - // Http-only helps mitigate the risk of client side script accessing the protected cookie. - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteStrictMode - c.SetCookie(cookie) -} - -// generateToken generates a jwt token. -func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &claimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{aud}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: auth.Issuer, - Subject: fmt.Sprintf("%d", userID), - }, - } - - // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token.Header["kid"] = auth.KeyID - - // Create the JWT string. - tokenString, err := token.SignedString(secret) - if err != nil { - return "", err - } - - return tokenString, nil -} +const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + userIDContextKey = "user-id" +) func extractTokenFromHeader(c echo.Context) (string, error) { authHeader := c.Request().Header.Get("Authorization") @@ -149,7 +84,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") } - claims := &claimsMessage{} + claims := &auth.ClaimsMessage{} _, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -163,7 +98,6 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e }) if err != nil { - RemoveTokensAndCookies(c) return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) } if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { @@ -188,7 +122,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } // Stores userID into context. - c.Set(auth.UserIDContextKey, userID) + c.Set(userIDContextKey, userID) return next(c) } } @@ -213,7 +147,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { } if user != nil { // Stores userID into context. - c.Set(auth.UserIDContextKey, user.ID) + c.Set(userIDContextKey, user.ID) return true } } diff --git a/api/v1/memo.go b/api/v1/memo.go index cf1c40d0..d61665ec 100644 --- a/api/v1/memo.go +++ b/api/v1/memo.go @@ -10,7 +10,6 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" @@ -156,7 +155,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error { } } - currentUserID, ok := c.Get(auth.UserIDContextKey).(int32) + currentUserID, ok := c.Get(userIDContextKey).(int32) if !ok { // Anonymous use should only fetch PUBLIC memos with specified user if findMemoMessage.CreatorID == nil { @@ -247,7 +246,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error { // - It's currently possible to create phantom resources and relations. Phantom relations will trigger backend 404's when fetching memo. func (s *APIV1Service) CreateMemo(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -407,7 +406,7 @@ func (s *APIV1Service) CreateMemo(c echo.Context) error { func (s *APIV1Service) GetAllMemos(c echo.Context) error { ctx := c.Request().Context() findMemoMessage := &store.FindMemo{} - _, ok := c.Get(auth.UserIDContextKey).(int32) + _, ok := c.Get(userIDContextKey).(int32) if !ok { findMemoMessage.VisibilityList = []store.Visibility{store.Public} } else { @@ -481,7 +480,7 @@ func (s *APIV1Service) GetMemoStats(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") } - currentUserID, ok := c.Get(auth.UserIDContextKey).(int32) + currentUserID, ok := c.Get(userIDContextKey).(int32) if !ok { findMemoMessage.VisibilityList = []store.Visibility{store.Public} } else { @@ -548,7 +547,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error { return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) } - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if memo.Visibility == store.Private { if !ok || memo.CreatorID != userID { return echo.NewHTTPError(http.StatusForbidden, "this memo is private only") @@ -580,7 +579,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error { // @Router /api/v1/memo/{memoId} [DELETE] func (s *APIV1Service) DeleteMemo(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -633,7 +632,7 @@ func (s *APIV1Service) DeleteMemo(c echo.Context) error { // - Passing 0 to createdTs and updatedTs will set them to 0 in the database, which is probably unwanted. func (s *APIV1Service) UpdateMemo(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/memo_organizer.go b/api/v1/memo_organizer.go index 7d872061..523efd13 100644 --- a/api/v1/memo_organizer.go +++ b/api/v1/memo_organizer.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) @@ -47,7 +46,7 @@ func (s *APIV1Service) CreateMemoOrganizer(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/memo_resource.go b/api/v1/memo_resource.go index f052e5dc..163c3052 100644 --- a/api/v1/memo_resource.go +++ b/api/v1/memo_resource.go @@ -7,7 +7,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) @@ -95,7 +94,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -145,7 +144,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error { // @Router /api/v1/memo/{memoId}/resource/{resourceId} [DELETE] func (s *APIV1Service) UnbindMemoResource(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/resource.go b/api/v1/resource.go index 38394fbe..ea17b1e7 100644 --- a/api/v1/resource.go +++ b/api/v1/resource.go @@ -20,7 +20,6 @@ import ( "github.com/disintegration/imaging" "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/storage/s3" @@ -105,7 +104,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) { // @Router /api/v1/resource [GET] func (s *APIV1Service) GetResourceList(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -145,7 +144,7 @@ func (s *APIV1Service) GetResourceList(c echo.Context) error { // @Router /api/v1/resource [POST] func (s *APIV1Service) CreateResource(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -197,7 +196,7 @@ func (s *APIV1Service) CreateResource(c echo.Context) error { // @Router /api/v1/resource/blob [POST] func (s *APIV1Service) UploadResource(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -270,7 +269,7 @@ func (s *APIV1Service) UploadResource(c echo.Context) error { // @Router /api/v1/resource/{resourceId} [DELETE] func (s *APIV1Service) DeleteResource(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -327,7 +326,7 @@ func (s *APIV1Service) DeleteResource(c echo.Context) error { // @Router /api/v1/resource/{resourceId} [PATCH] func (s *APIV1Service) UpdateResource(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -398,7 +397,7 @@ func (s *APIV1Service) streamResource(c echo.Context) error { } // Protected resource require a logined user - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if resourceVisibility == store.Protected && (!ok || userID <= 0) { return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err) } diff --git a/api/v1/storage.go b/api/v1/storage.go index e00ea65a..948528ae 100644 --- a/api/v1/storage.go +++ b/api/v1/storage.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) @@ -82,7 +81,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { // @Router /api/v1/storage [GET] func (s *APIV1Service) GetStorageList(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -129,7 +128,7 @@ func (s *APIV1Service) GetStorageList(c echo.Context) error { // @Router /api/v1/storage [POST] func (s *APIV1Service) CreateStorage(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -190,7 +189,7 @@ func (s *APIV1Service) CreateStorage(c echo.Context) error { // - error message "Storage service %d is using" probably should be "Storage service %d is in use". func (s *APIV1Service) DeleteStorage(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -246,7 +245,7 @@ func (s *APIV1Service) DeleteStorage(c echo.Context) error { // @Router /api/v1/storage/{storageId} [PATCH] func (s *APIV1Service) UpdateStorage(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/system.go b/api/v1/system.go index 766c1210..971cf895 100644 --- a/api/v1/system.go +++ b/api/v1/system.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -168,7 +167,7 @@ func (s *APIV1Service) GetSystemStatus(c echo.Context) error { // @Router /api/v1/system/vacuum [POST] func (s *APIV1Service) ExecVacuum(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/system_setting.go b/api/v1/system_setting.go index 5c4527c6..ca953f4d 100644 --- a/api/v1/system_setting.go +++ b/api/v1/system_setting.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" ) @@ -95,7 +94,7 @@ func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) { // @Router /api/v1/system/setting [GET] func (s *APIV1Service) GetSystemSettingList(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -138,7 +137,7 @@ func (s *APIV1Service) GetSystemSettingList(c echo.Context) error { // @Router /api/v1/system/setting [POST] func (s *APIV1Service) CreateSystemSetting(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/tag.go b/api/v1/tag.go index 50e7d59f..5a13d2de 100644 --- a/api/v1/tag.go +++ b/api/v1/tag.go @@ -9,7 +9,6 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) @@ -46,7 +45,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { // @Router /api/v1/tag [GET] func (s *APIV1Service) GetTagList(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") } @@ -80,7 +79,7 @@ func (s *APIV1Service) GetTagList(c echo.Context) error { // @Router /api/v1/tag [POST] func (s *APIV1Service) CreateTag(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -122,7 +121,7 @@ func (s *APIV1Service) CreateTag(c echo.Context) error { // @Router /api/v1/tag/delete [POST] func (s *APIV1Service) DeleteTag(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -157,7 +156,7 @@ func (s *APIV1Service) DeleteTag(c echo.Context) error { // @Router /api/v1/tag/suggestion [GET] func (s *APIV1Service) GetTagSuggestion(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user session") } diff --git a/api/v1/user.go b/api/v1/user.go index e2e192a4..c971c28d 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -8,7 +8,6 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "golang.org/x/crypto/bcrypt" @@ -119,7 +118,7 @@ func (s *APIV1Service) GetUserList(c echo.Context) error { // @Router /api/v1/user [POST] func (s *APIV1Service) CreateUser(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } @@ -184,7 +183,7 @@ func (s *APIV1Service) CreateUser(c echo.Context) error { // @Router /api/v1/user/me [GET] func (s *APIV1Service) GetCurrentUser(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } @@ -287,7 +286,7 @@ func (s *APIV1Service) GetUserByID(c echo.Context) error { // @Router /api/v1/user/{id} [DELETE] func (s *APIV1Service) DeleteUser(c echo.Context) error { ctx := c.Request().Context() - currentUserID, ok := c.Get(auth.UserIDContextKey).(int32) + currentUserID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -337,7 +336,7 @@ func (s *APIV1Service) UpdateUser(c echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err) } - currentUserID, ok := c.Get(auth.UserIDContextKey).(int32) + currentUserID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/user_setting.go b/api/v1/user_setting.go index 15c3086c..4c61182b 100644 --- a/api/v1/user_setting.go +++ b/api/v1/user_setting.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/usememos/memos/api/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) @@ -97,7 +96,7 @@ func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) { // @Router /api/v1/user/setting [POST] func (s *APIV1Service) UpsertUserSetting(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(auth.UserIDContextKey).(int32) + userID, ok := c.Get(userIDContextKey).(int32) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } diff --git a/api/v2/acl.go b/api/v2/acl.go index 076c86bd..1affc904 100644 --- a/api/v2/acl.go +++ b/api/v2/acl.go @@ -3,14 +3,11 @@ package v2 import ( "context" "net/http" - "strconv" "strings" - "time" "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" "github.com/usememos/memos/api/auth" - "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -22,9 +19,9 @@ import ( type ContextKey int const ( - // The key name used to store user id in the context + // The key name used to store username in the context // user id is extracted from the jwt token subject field. - UserIDContextKey ContextKey = iota + usernameContextKey ContextKey = iota ) // GRPCAuthInterceptor is the auth interceptor for gRPC server. @@ -52,7 +49,7 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, status.Errorf(codes.Unauthenticated, err.Error()) } - userID, err := in.authenticate(ctx, accessTokenStr) + username, err := in.authenticate(ctx, accessTokenStr) if err != nil { if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { return handler(ctx, request) @@ -60,28 +57,28 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, err } user, err := in.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, + Username: &username, }) if err != nil { return nil, errors.Wrap(err, "failed to get user") } if user == nil { - return nil, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID) + return nil, errors.Errorf("user %q not exists", username) } if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin { - return nil, status.Errorf(codes.PermissionDenied, "user ID %q is not admin", userID) + return nil, errors.Errorf("user %q is not admin", username) } // Stores userID into context. - childCtx := context.WithValue(ctx, UserIDContextKey, userID) + childCtx := context.WithValue(ctx, usernameContextKey, username) return handler(childCtx, request) } -func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (int32, error) { +func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (string, error) { if accessTokenStr == "" { - return 0, status.Errorf(codes.Unauthenticated, "access token not found") + return "", status.Errorf(codes.Unauthenticated, "access token not found") } - claims := &claimsMessage{} + claims := &auth.ClaimsMessage{} _, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -94,34 +91,31 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) }) if err != nil { - return 0, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") + return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token") } if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { - return 0, status.Errorf(codes.Unauthenticated, + return "", status.Errorf(codes.Unauthenticated, "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", claims.Audience, auth.AccessTokenAudienceName, ) } - userID, err := util.ConvertStringToInt32(claims.Subject) - if err != nil { - return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) - } + username := claims.Name user, err := in.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, + Username: &username, }) if err != nil { - return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID) + return "", errors.Wrap(err, "failed to get user") } if user == nil { - return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID) + return "", errors.Errorf("user %q not exists in the access token", username) } if user.RowStatus == store.Archived { - return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID) + return "", errors.Errorf("user %q is archived", username) } - return userID, nil + return username, nil } func getTokenFromMetadata(md metadata.MD) (string, error) { @@ -154,41 +148,3 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { } return false } - -type claimsMessage struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(auth.AccessTokenDuration) - return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &claimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{aud}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: auth.Issuer, - Subject: strconv.Itoa(userID), - }, - } - - // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token.Header["kid"] = auth.KeyID - - // Create the JWT string. - tokenString, err := token.SignedString(secret) - if err != nil { - return "", err - } - - return tokenString, nil -} diff --git a/api/v2/common.go b/api/v2/common.go index 0115f813..fac53be7 100644 --- a/api/v2/common.go +++ b/api/v2/common.go @@ -1,6 +1,8 @@ package v2 import ( + "context" + apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/store" ) @@ -26,3 +28,17 @@ func convertRowStatusToStore(rowStatus apiv2pb.RowStatus) store.RowStatus { return store.Normal } } + +func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) { + username, ok := ctx.Value(usernameContextKey).(string) + if !ok { + return nil, nil + } + user, err := s.GetUser(ctx, &store.FindUser{ + Username: &username, + }) + if err != nil { + return nil, err + } + return user, nil +} diff --git a/api/v2/memo_service.go b/api/v2/memo_service.go index 5e2f7391..9b798639 100644 --- a/api/v2/memo_service.go +++ b/api/v2/memo_service.go @@ -42,9 +42,9 @@ func (s *MemoService) ListMemos(ctx context.Context, request *apiv2pb.ListMemosR memoFind.CreatedTsAfter = filter.CreatedTsAfter } } - userIDPtr := ctx.Value(UserIDContextKey) + user, _ := getCurrentUser(ctx, s.Store) // If the user is not authenticated, only public memos are visible. - if userIDPtr == nil { + if user == nil { memoFind.VisibilityList = []store.Visibility{store.Public} } if request.PageSize != 0 { @@ -80,12 +80,14 @@ func (s *MemoService) GetMemo(ctx context.Context, request *apiv2pb.GetMemoReque return nil, status.Errorf(codes.NotFound, "memo not found") } if memo.Visibility != store.Public { - userIDPtr := ctx.Value(UserIDContextKey) - if userIDPtr == nil { - return nil, status.Errorf(codes.Unauthenticated, "unauthenticated") + user, err := getCurrentUser(ctx, s.Store) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user") } - userID := userIDPtr.(int32) - if memo.Visibility == store.Private && memo.CreatorID != userID { + if user == nil { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + if memo.Visibility == store.Private && memo.CreatorID != user.ID { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } } diff --git a/api/v2/system_service.go b/api/v2/system_service.go index 4eef0b4f..38bee07e 100644 --- a/api/v2/system_service.go +++ b/api/v2/system_service.go @@ -31,22 +31,16 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI defaultSystemInfo := &apiv2pb.SystemInfo{} // Get the database size if the user is a host. - userIDPtr := ctx.Value(UserIDContextKey) - if userIDPtr != nil { - userID := userIDPtr.(int32) - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) + currentUser, err := getCurrentUser(ctx, s.Store) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if currentUser != nil && currentUser.Role == store.RoleHost { + fi, err := os.Stat(s.Profile.DSN) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) - } - if user != nil && user.Role == store.RoleHost { - fi, err := os.Stat(s.Profile.DSN) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err) - } - defaultSystemInfo.DbSize = fi.Size() + return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err) } + defaultSystemInfo.DbSize = fi.Size() } response := &apiv2pb.GetSystemInfoResponse{ @@ -56,12 +50,9 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI } func (s *SystemService) UpdateSystemInfo(ctx context.Context, request *apiv2pb.UpdateSystemInfoRequest) (*apiv2pb.UpdateSystemInfoResponse, error) { - userID := ctx.Value(UserIDContextKey).(int32) - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) + user, err := getCurrentUser(ctx, s.Store) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } if user.Role != store.RoleHost { return nil, status.Errorf(codes.PermissionDenied, "permission denied") diff --git a/api/v2/user_service.go b/api/v2/user_service.go index 32ff29a9..e9ed71e8 100644 --- a/api/v2/user_service.go +++ b/api/v2/user_service.go @@ -5,11 +5,16 @@ import ( "net/http" "time" + "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" + "github.com/pkg/errors" + "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" "golang.org/x/crypto/bcrypt" + "golang.org/x/exp/slices" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -18,13 +23,15 @@ import ( type UserService struct { apiv2pb.UnimplementedUserServiceServer - Store *store.Store + Store *store.Store + Secret string } // NewUserService creates a new UserService. -func NewUserService(store *store.Store) *UserService { +func NewUserService(store *store.Store, secret string) *UserService { return &UserService{ - Store: store, + Store: store, + Secret: secret, } } @@ -40,13 +47,10 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque } userMessage := convertUserFromStore(user) - userIDPtr := ctx.Value(UserIDContextKey) - if userIDPtr != nil { - userID := userIDPtr.(int32) - if userID != userMessage.Id { - // Data desensitization. - userMessage.OpenId = "" - } + currentUser, _ := getCurrentUser(ctx, s.Store) + if currentUser == nil || currentUser.ID != user.ID { + // Data desensitization. + userMessage.OpenId = "" } response := &apiv2pb.GetUserResponse{ @@ -56,14 +60,11 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque } func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) { - userID := ctx.Value(UserIDContextKey).(int32) - currentUser, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) + currentUser, err := getCurrentUser(ctx, s.Store) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { + if currentUser.Username != request.Username && currentUser.Role != store.RoleAdmin { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } if request.UpdateMask == nil || len(request.UpdateMask) == 0 { @@ -72,7 +73,7 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse currentTs := time.Now().Unix() update := &store.UpdateUser{ - ID: userID, + ID: currentUser.ID, UpdatedTs: ¤tTs, } for _, path := range request.UpdateMask { @@ -116,6 +117,162 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse return response, nil } +func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) { + user, err := getCurrentUser(ctx, s.Store) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if user == nil || user.Username != request.Username { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + + userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err) + } + + accessTokens := []*apiv2pb.UserAccessToken{} + for _, userAccessToken := range userAccessTokens { + claims := &auth.ClaimsMessage{} + _, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(s.Secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + // If the access token is invalid or expired, just ignore it. + continue + } + + userAccessToken := &apiv2pb.UserAccessToken{ + AccessToken: userAccessToken.AccessToken, + Description: userAccessToken.Description, + IssuedAt: timestamppb.New(claims.IssuedAt.Time), + } + if claims.ExpiresAt != nil { + userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time) + } + accessTokens = append(accessTokens, userAccessToken) + } + + // Sort by issued time in descending order. + slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) bool { + return i.IssuedAt.Seconds > j.IssuedAt.Seconds + }) + response := &apiv2pb.ListUserAccessTokensResponse{ + AccessTokens: accessTokens, + } + return response, nil +} + +func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) { + user, err := getCurrentUser(ctx, s.Store) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + + accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) + } + + claims := &auth.ClaimsMessage{} + _, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(s.Secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err) + } + + // Upsert the access token to user setting store. + if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.UserAccessToken.Description); err != nil { + return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err) + } + + userAccessToken := &apiv2pb.UserAccessToken{ + AccessToken: accessToken, + Description: request.UserAccessToken.Description, + IssuedAt: timestamppb.New(claims.IssuedAt.Time), + } + if claims.ExpiresAt != nil { + userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time) + } + response := &apiv2pb.CreateUserAccessTokenResponse{ + AccessToken: userAccessToken, + } + return response, nil +} + +func (s *UserService) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) { + user, err := getCurrentUser(ctx, s.Store) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + + userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err) + } + updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{} + for _, userAccessToken := range userAccessTokens { + if userAccessToken.AccessToken == request.AccessToken { + continue + } + updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken) + } + if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + Value: &storepb.UserSetting_AccessTokens{ + AccessTokens: &storepb.AccessTokensUserSetting{ + AccessTokens: updatedUserAccessTokens, + }, + }, + }); err != nil { + return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err) + } + + return &apiv2pb.DeleteUserAccessTokenResponse{}, nil +} + +func (s *UserService) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error { + userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID) + if err != nil { + return errors.Wrap(err, "failed to get user access tokens") + } + userAccessToken := storepb.AccessTokensUserSetting_AccessToken{ + AccessToken: accessToken, + Description: description, + } + userAccessTokens = append(userAccessTokens, &userAccessToken) + if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + Value: &storepb.UserSetting_AccessTokens{ + AccessTokens: &storepb.AccessTokensUserSetting{ + AccessTokens: userAccessTokens, + }, + }, + }); err != nil { + return errors.Wrap(err, "failed to upsert user setting") + } + return nil +} + func convertUserFromStore(user *store.User) *apiv2pb.User { return &apiv2pb.User{ Id: int32(user.ID), diff --git a/api/v2/v2.go b/api/v2/v2.go index a2ba0832..55e330ff 100644 --- a/api/v2/v2.go +++ b/api/v2/v2.go @@ -30,7 +30,7 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store ), ) apiv2pb.RegisterSystemServiceServer(grpcServer, NewSystemService(profile, store)) - apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store)) + apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store, secret)) apiv2pb.RegisterMemoServiceServer(grpcServer, NewMemoService(store)) apiv2pb.RegisterTagServiceServer(grpcServer, NewTagService(store)) diff --git a/proto/gen/store/README.md b/proto/gen/store/README.md index 0e87c071..bf36420a 100644 --- a/proto/gen/store/README.md +++ b/proto/gen/store/README.md @@ -146,6 +146,7 @@ | Name | Number | Description | | ---- | ------ | ----------- | | USER_SETTING_KEY_UNSPECIFIED | 0 | | +| USER_SETTING_ACCESS_TOKENS | 1 | Access tokens for the user. | diff --git a/proto/gen/store/user_setting.pb.go b/proto/gen/store/user_setting.pb.go index 203cc890..c2e6419e 100644 --- a/proto/gen/store/user_setting.pb.go +++ b/proto/gen/store/user_setting.pb.go @@ -24,15 +24,19 @@ type UserSettingKey int32 const ( UserSettingKey_USER_SETTING_KEY_UNSPECIFIED UserSettingKey = 0 + // Access tokens for the user. + UserSettingKey_USER_SETTING_ACCESS_TOKENS UserSettingKey = 1 ) // Enum value maps for UserSettingKey. var ( UserSettingKey_name = map[int32]string{ 0: "USER_SETTING_KEY_UNSPECIFIED", + 1: "USER_SETTING_ACCESS_TOKENS", } UserSettingKey_value = map[string]int32{ "USER_SETTING_KEY_UNSPECIFIED": 0, + "USER_SETTING_ACCESS_TOKENS": 1, } ) @@ -279,10 +283,12 @@ var file_store_user_setting_proto_rawDesc = []byte{ 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x32, 0x0a, 0x0e, + 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x52, 0x0a, 0x0e, 0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x4b, 0x65, 0x79, 0x12, 0x20, 0x0a, 0x1c, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47, 0x5f, 0x4b, 0x45, 0x59, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, + 0x12, 0x1e, 0x0a, 0x1a, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47, + 0x5f, 0x41, 0x43, 0x43, 0x45, 0x53, 0x53, 0x5f, 0x54, 0x4f, 0x4b, 0x45, 0x4e, 0x53, 0x10, 0x01, 0x42, 0x9b, 0x01, 0x0a, 0x0f, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x65, 0x6d, 0x6f, 0x73, 0x2e, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x42, 0x10, 0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, diff --git a/proto/store/user_setting.proto b/proto/store/user_setting.proto index b4607af1..6d46fc98 100644 --- a/proto/store/user_setting.proto +++ b/proto/store/user_setting.proto @@ -16,6 +16,9 @@ message UserSetting { enum UserSettingKey { USER_SETTING_KEY_UNSPECIFIED = 0; + + // Access tokens for the user. + USER_SETTING_ACCESS_TOKENS = 1; } message AccessTokensUserSetting { diff --git a/store/user_setting.go b/store/user_setting.go index f577918b..1d8301af 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -3,7 +3,11 @@ package store import ( "context" "database/sql" + "errors" "strings" + + storepb "github.com/usememos/memos/proto/gen/store" + "google.golang.org/protobuf/encoding/protojson" ) type UserSetting struct { @@ -102,6 +106,138 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use return userSetting, nil } +type FindUserSettingV1 struct { + UserID *int32 + Key storepb.UserSettingKey +} + +func (s *Store) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { + stmt := ` + INSERT INTO user_setting ( + user_id, key, value + ) + VALUES (?, ?, ?) + ON CONFLICT(user_id, key) DO UPDATE + SET value = EXCLUDED.value + ` + var valueString string + if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else { + return nil, errors.New("invalid user setting key") + } + + if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil { + return nil, err + } + + userSettingMessage := upsert + s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage) + return userSettingMessage, nil +} + +func (s *Store) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { + where, args = append(where, "key = ?"), append(args, v.String()) + } + if v := find.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *find.UserID) + } + + query := ` + SELECT + user_id, + key, + value + FROM user_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + userSettingList := make([]*storepb.UserSetting, 0) + for rows.Next() { + userSetting := &storepb.UserSetting{} + var keyString, valueString string + if err := rows.Scan( + &userSetting.UserId, + &keyString, + &valueString, + ); err != nil { + return nil, err + } + userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) + if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + accessTokensUserSetting := &storepb.AccessTokensUserSetting{} + if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_AccessTokens{ + AccessTokens: accessTokensUserSetting, + } + } else { + // Skip unknown user setting v1 key. + continue + } + userSettingList = append(userSettingList, userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + for _, userSetting := range userSettingList { + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + } + return userSettingList, nil +} + +func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) (*storepb.UserSetting, error) { + if find.UserID != nil { + if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok { + return cache.(*storepb.UserSetting), nil + } + } + + list, err := s.ListUserSettingsV1(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + userSetting := list[0] + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + return userSetting, nil +} + +// GetUserAccessTokens returns the access tokens of the user. +func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) { + userSetting, err := s.GetUserSettingV1(ctx, &FindUserSettingV1{ + UserID: &userID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + }) + if err != nil { + return nil, err + } + if userSetting == nil { + return []*storepb.AccessTokensUserSetting_AccessToken{}, nil + } + + accessTokensUserSetting := userSetting.GetAccessTokens() + return accessTokensUserSetting.AccessTokens, nil +} + func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/web/src/types/proto/store/user_setting_pb.d.ts b/web/src/types/proto/store/user_setting_pb.d.ts index 1c6fd18d..bb2af523 100644 --- a/web/src/types/proto/store/user_setting_pb.d.ts +++ b/web/src/types/proto/store/user_setting_pb.d.ts @@ -13,7 +13,14 @@ export declare enum UserSettingKey { /** * @generated from enum value: USER_SETTING_KEY_UNSPECIFIED = 0; */ - UNSPECIFIED = 0, + USER_SETTING_KEY_UNSPECIFIED = 0, + + /** + * Access tokens for the user. + * + * @generated from enum value: USER_SETTING_ACCESS_TOKENS = 1; + */ + USER_SETTING_ACCESS_TOKENS = 1, } /** diff --git a/web/src/types/proto/store/user_setting_pb.js b/web/src/types/proto/store/user_setting_pb.js index 0e2a875d..eea4ee5d 100644 --- a/web/src/types/proto/store/user_setting_pb.js +++ b/web/src/types/proto/store/user_setting_pb.js @@ -11,7 +11,8 @@ import { proto3 } from "@bufbuild/protobuf"; export const UserSettingKey = proto3.makeEnum( "memos.store.UserSettingKey", [ - {no: 0, name: "USER_SETTING_KEY_UNSPECIFIED", localName: "UNSPECIFIED"}, + {no: 0, name: "USER_SETTING_KEY_UNSPECIFIED"}, + {no: 1, name: "USER_SETTING_ACCESS_TOKENS"}, ], );