feat: impl user access token api

This commit is contained in:
Steven 2023-09-14 20:16:17 +08:00
parent 41e26f56e9
commit 42bd9b194b
26 changed files with 507 additions and 240 deletions

View file

@ -1,13 +1,13 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v4"
)
const (
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
UserIDContextKey = "user-id"
// issuer is the issuer of the jwt token.
Issuer = "memos"
// Signing key section. For now, this is only used for signing, not for verifying since we only
@ -23,3 +23,42 @@ const (
// AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "memos.access-token"
)
type ClaimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token.
// username is the email of the user.
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret string) (string, error) {
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret))
}
// generateToken generates a jwt token.
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
registeredClaims := jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{audience},
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: fmt.Sprint(userID),
}
if expirationTime.After(time.Now()) {
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
Name: username,
RegisteredClaims: registeredClaims,
})
token.Header["kid"] = KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

View file

@ -5,9 +5,11 @@ import (
"fmt"
"net/http"
"regexp"
"time"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
@ -94,12 +96,15 @@ func (s *APIV1Service) SignIn(c echo.Context) error {
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
}
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.createAuthSignInActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
@ -213,12 +218,15 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
}
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.createAuthSignInActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
@ -304,13 +312,15 @@ func (s *APIV1Service) SignUp(c echo.Context) error {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.createAuthSignUpActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
@ -358,3 +368,22 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User
}
return err
}
// RemoveTokensAndCookies removes the jwt token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}

View file

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
)
@ -88,7 +87,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
}
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
isHostUser := false
if ok {
user, err := s.Store.GetUser(ctx, &store.FindUser{
@ -129,7 +128,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error {
// @Router /api/v1/idp [POST]
func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -177,7 +176,7 @@ func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error {
// @Router /api/v1/idp/{idpId} [GET]
func (s *APIV1Service) GetIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -224,7 +223,7 @@ func (s *APIV1Service) GetIdentityProvider(c echo.Context) error {
// @Router /api/v1/idp/{idpId} [DELETE]
func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -266,7 +265,7 @@ func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error {
// @Router /api/v1/idp/{idpId} [PATCH]
func (s *APIV1Service) UpdateIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
@ -14,75 +13,11 @@ import (
"github.com/usememos/memos/store"
)
type claimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate access token")
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
return nil
}
// RemoveTokensAndCookies removes the jwt token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}
// generateToken generates a jwt token.
func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &claimsMessage{
Name: username,
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{aud},
// In JWT, the expiry time is expressed as unix milliseconds.
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: auth.Issuer,
Subject: fmt.Sprintf("%d", userID),
},
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = auth.KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}
const (
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
userIDContextKey = "user-id"
)
func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
@ -149,7 +84,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
}
claims := &claimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
@ -163,7 +98,6 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
})
if err != nil {
RemoveTokensAndCookies(c)
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
}
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
@ -188,7 +122,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
}
// Stores userID into context.
c.Set(auth.UserIDContextKey, userID)
c.Set(userIDContextKey, userID)
return next(c)
}
}
@ -213,7 +147,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
}
if user != nil {
// Stores userID into context.
c.Set(auth.UserIDContextKey, user.ID)
c.Set(userIDContextKey, user.ID)
return true
}
}

View file

@ -10,7 +10,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/log"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
@ -156,7 +155,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error {
}
}
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
// Anonymous use should only fetch PUBLIC memos with specified user
if findMemoMessage.CreatorID == nil {
@ -247,7 +246,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error {
// - It's currently possible to create phantom resources and relations. Phantom relations will trigger backend 404's when fetching memo.
func (s *APIV1Service) CreateMemo(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -407,7 +406,7 @@ func (s *APIV1Service) CreateMemo(c echo.Context) error {
func (s *APIV1Service) GetAllMemos(c echo.Context) error {
ctx := c.Request().Context()
findMemoMessage := &store.FindMemo{}
_, ok := c.Get(auth.UserIDContextKey).(int32)
_, ok := c.Get(userIDContextKey).(int32)
if !ok {
findMemoMessage.VisibilityList = []store.Visibility{store.Public}
} else {
@ -481,7 +480,7 @@ func (s *APIV1Service) GetMemoStats(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
}
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
findMemoMessage.VisibilityList = []store.Visibility{store.Public}
} else {
@ -548,7 +547,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if memo.Visibility == store.Private {
if !ok || memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusForbidden, "this memo is private only")
@ -580,7 +579,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error {
// @Router /api/v1/memo/{memoId} [DELETE]
func (s *APIV1Service) DeleteMemo(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -633,7 +632,7 @@ func (s *APIV1Service) DeleteMemo(c echo.Context) error {
// - Passing 0 to createdTs and updatedTs will set them to 0 in the database, which is probably unwanted.
func (s *APIV1Service) UpdateMemo(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
)
@ -47,7 +46,7 @@ func (s *APIV1Service) CreateMemoOrganizer(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -7,7 +7,6 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
)
@ -95,7 +94,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -145,7 +144,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error {
// @Router /api/v1/memo/{memoId}/resource/{resourceId} [DELETE]
func (s *APIV1Service) UnbindMemoResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -20,7 +20,6 @@ import (
"github.com/disintegration/imaging"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/log"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/storage/s3"
@ -105,7 +104,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) {
// @Router /api/v1/resource [GET]
func (s *APIV1Service) GetResourceList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -145,7 +144,7 @@ func (s *APIV1Service) GetResourceList(c echo.Context) error {
// @Router /api/v1/resource [POST]
func (s *APIV1Service) CreateResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -197,7 +196,7 @@ func (s *APIV1Service) CreateResource(c echo.Context) error {
// @Router /api/v1/resource/blob [POST]
func (s *APIV1Service) UploadResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -270,7 +269,7 @@ func (s *APIV1Service) UploadResource(c echo.Context) error {
// @Router /api/v1/resource/{resourceId} [DELETE]
func (s *APIV1Service) DeleteResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -327,7 +326,7 @@ func (s *APIV1Service) DeleteResource(c echo.Context) error {
// @Router /api/v1/resource/{resourceId} [PATCH]
func (s *APIV1Service) UpdateResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -398,7 +397,7 @@ func (s *APIV1Service) streamResource(c echo.Context) error {
}
// Protected resource require a logined user
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if resourceVisibility == store.Protected && (!ok || userID <= 0) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
}

View file

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
)
@ -82,7 +81,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) {
// @Router /api/v1/storage [GET]
func (s *APIV1Service) GetStorageList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -129,7 +128,7 @@ func (s *APIV1Service) GetStorageList(c echo.Context) error {
// @Router /api/v1/storage [POST]
func (s *APIV1Service) CreateStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -190,7 +189,7 @@ func (s *APIV1Service) CreateStorage(c echo.Context) error {
// - error message "Storage service %d is using" probably should be "Storage service %d is in use".
func (s *APIV1Service) DeleteStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -246,7 +245,7 @@ func (s *APIV1Service) DeleteStorage(c echo.Context) error {
// @Router /api/v1/storage/{storageId} [PATCH]
func (s *APIV1Service) UpdateStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -5,7 +5,6 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/log"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
@ -168,7 +167,7 @@ func (s *APIV1Service) GetSystemStatus(c echo.Context) error {
// @Router /api/v1/system/vacuum [POST]
func (s *APIV1Service) ExecVacuum(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -7,7 +7,6 @@ import (
"strings"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/store"
)
@ -95,7 +94,7 @@ func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) {
// @Router /api/v1/system/setting [GET]
func (s *APIV1Service) GetSystemSettingList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -138,7 +137,7 @@ func (s *APIV1Service) GetSystemSettingList(c echo.Context) error {
// @Router /api/v1/system/setting [POST]
func (s *APIV1Service) CreateSystemSetting(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -9,7 +9,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/store"
"golang.org/x/exp/slices"
)
@ -46,7 +45,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) {
// @Router /api/v1/tag [GET]
func (s *APIV1Service) GetTagList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}
@ -80,7 +79,7 @@ func (s *APIV1Service) GetTagList(c echo.Context) error {
// @Router /api/v1/tag [POST]
func (s *APIV1Service) CreateTag(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -122,7 +121,7 @@ func (s *APIV1Service) CreateTag(c echo.Context) error {
// @Router /api/v1/tag/delete [POST]
func (s *APIV1Service) DeleteTag(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -157,7 +156,7 @@ func (s *APIV1Service) DeleteTag(c echo.Context) error {
// @Router /api/v1/tag/suggestion [GET]
func (s *APIV1Service) GetTagSuggestion(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user session")
}

View file

@ -8,7 +8,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt"
@ -119,7 +118,7 @@ func (s *APIV1Service) GetUserList(c echo.Context) error {
// @Router /api/v1/user [POST]
func (s *APIV1Service) CreateUser(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
@ -184,7 +183,7 @@ func (s *APIV1Service) CreateUser(c echo.Context) error {
// @Router /api/v1/user/me [GET]
func (s *APIV1Service) GetCurrentUser(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
@ -287,7 +286,7 @@ func (s *APIV1Service) GetUserByID(c echo.Context) error {
// @Router /api/v1/user/{id} [DELETE]
func (s *APIV1Service) DeleteUser(c echo.Context) error {
ctx := c.Request().Context()
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
@ -337,7 +336,7 @@ func (s *APIV1Service) UpdateUser(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}

View file

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/store"
"golang.org/x/exp/slices"
)
@ -97,7 +96,7 @@ func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) {
// @Router /api/v1/user/setting [POST]
func (s *APIV1Service) UpsertUserSetting(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(auth.UserIDContextKey).(int32)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}

View file

@ -3,14 +3,11 @@ package v2
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -22,9 +19,9 @@ import (
type ContextKey int
const (
// The key name used to store user id in the context
// The key name used to store username in the context
// user id is extracted from the jwt token subject field.
UserIDContextKey ContextKey = iota
usernameContextKey ContextKey = iota
)
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
@ -52,7 +49,7 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
return nil, status.Errorf(codes.Unauthenticated, err.Error())
}
userID, err := in.authenticate(ctx, accessTokenStr)
username, err := in.authenticate(ctx, accessTokenStr)
if err != nil {
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
@ -60,28 +57,28 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
return nil, err
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
Username: &username,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID)
return nil, errors.Errorf("user %q not exists", username)
}
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "user ID %q is not admin", userID)
return nil, errors.Errorf("user %q is not admin", username)
}
// Stores userID into context.
childCtx := context.WithValue(ctx, UserIDContextKey, userID)
childCtx := context.WithValue(ctx, usernameContextKey, username)
return handler(childCtx, request)
}
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (int32, error) {
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (string, error) {
if accessTokenStr == "" {
return 0, status.Errorf(codes.Unauthenticated, "access token not found")
return "", status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &claimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
@ -94,34 +91,31 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
return 0, status.Errorf(codes.Unauthenticated,
return "", status.Errorf(codes.Unauthenticated,
"invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
claims.Audience,
auth.AccessTokenAudienceName,
)
}
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject)
}
username := claims.Name
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
Username: &username,
})
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID)
return "", errors.Wrap(err, "failed to get user")
}
if user == nil {
return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID)
return "", errors.Errorf("user %q not exists in the access token", username)
}
if user.RowStatus == store.Archived {
return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID)
return "", errors.Errorf("user %q is archived", username)
}
return userID, nil
return username, nil
}
func getTokenFromMetadata(md metadata.MD) (string, error) {
@ -154,41 +148,3 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
}
return false
}
type claimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}
func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &claimsMessage{
Name: username,
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{aud},
// In JWT, the expiry time is expressed as unix milliseconds.
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: auth.Issuer,
Subject: strconv.Itoa(userID),
},
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = auth.KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

View file

@ -1,6 +1,8 @@
package v2
import (
"context"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
@ -26,3 +28,17 @@ func convertRowStatusToStore(rowStatus apiv2pb.RowStatus) store.RowStatus {
return store.Normal
}
}
func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) {
username, ok := ctx.Value(usernameContextKey).(string)
if !ok {
return nil, nil
}
user, err := s.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, err
}
return user, nil
}

View file

@ -42,9 +42,9 @@ func (s *MemoService) ListMemos(ctx context.Context, request *apiv2pb.ListMemosR
memoFind.CreatedTsAfter = filter.CreatedTsAfter
}
}
userIDPtr := ctx.Value(UserIDContextKey)
user, _ := getCurrentUser(ctx, s.Store)
// If the user is not authenticated, only public memos are visible.
if userIDPtr == nil {
if user == nil {
memoFind.VisibilityList = []store.Visibility{store.Public}
}
if request.PageSize != 0 {
@ -80,12 +80,14 @@ func (s *MemoService) GetMemo(ctx context.Context, request *apiv2pb.GetMemoReque
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
userIDPtr := ctx.Value(UserIDContextKey)
if userIDPtr == nil {
return nil, status.Errorf(codes.Unauthenticated, "unauthenticated")
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
userID := userIDPtr.(int32)
if memo.Visibility == store.Private && memo.CreatorID != userID {
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}

View file

@ -31,22 +31,16 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI
defaultSystemInfo := &apiv2pb.SystemInfo{}
// Get the database size if the user is a host.
userIDPtr := ctx.Value(UserIDContextKey)
if userIDPtr != nil {
userID := userIDPtr.(int32)
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser != nil && currentUser.Role == store.RoleHost {
fi, err := os.Stat(s.Profile.DSN)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user != nil && user.Role == store.RoleHost {
fi, err := os.Stat(s.Profile.DSN)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err)
}
defaultSystemInfo.DbSize = fi.Size()
return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err)
}
defaultSystemInfo.DbSize = fi.Size()
}
response := &apiv2pb.GetSystemInfoResponse{
@ -56,12 +50,9 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI
}
func (s *SystemService) UpdateSystemInfo(ctx context.Context, request *apiv2pb.UpdateSystemInfoRequest) (*apiv2pb.UpdateSystemInfoResponse, error) {
userID := ctx.Value(UserIDContextKey).(int32)
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")

View file

@ -5,11 +5,16 @@ import (
"net/http"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/common/util"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
@ -18,13 +23,15 @@ import (
type UserService struct {
apiv2pb.UnimplementedUserServiceServer
Store *store.Store
Store *store.Store
Secret string
}
// NewUserService creates a new UserService.
func NewUserService(store *store.Store) *UserService {
func NewUserService(store *store.Store, secret string) *UserService {
return &UserService{
Store: store,
Store: store,
Secret: secret,
}
}
@ -40,13 +47,10 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
}
userMessage := convertUserFromStore(user)
userIDPtr := ctx.Value(UserIDContextKey)
if userIDPtr != nil {
userID := userIDPtr.(int32)
if userID != userMessage.Id {
// Data desensitization.
userMessage.OpenId = ""
}
currentUser, _ := getCurrentUser(ctx, s.Store)
if currentUser == nil || currentUser.ID != user.ID {
// Data desensitization.
userMessage.OpenId = ""
}
response := &apiv2pb.GetUserResponse{
@ -56,14 +60,11 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
}
func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
userID := ctx.Value(UserIDContextKey).(int32)
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
if currentUser.Username != request.Username && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask) == 0 {
@ -72,7 +73,7 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse
currentTs := time.Now().Unix()
update := &store.UpdateUser{
ID: userID,
ID: currentUser.ID,
UpdatedTs: &currentTs,
}
for _, path := range request.UpdateMask {
@ -116,6 +117,162 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse
return response, nil
}
func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil || user.Username != request.Username {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
accessTokens := []*apiv2pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
// If the access token is invalid or expired, just ignore it.
continue
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
accessTokens = append(accessTokens, userAccessToken)
}
// Sort by issued time in descending order.
slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) bool {
return i.IssuedAt.Seconds > j.IssuedAt.Seconds
})
response := &apiv2pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}
func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &auth.ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}
// Upsert the access token to user setting store.
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.UserAccessToken.Description); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: accessToken,
Description: request.UserAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
response := &apiv2pb.CreateUserAccessTokenResponse{
AccessToken: userAccessToken,
}
return response, nil
}
func (s *UserService) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == request.AccessToken {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
}
func (s *UserService) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: description,
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}
func convertUserFromStore(user *store.User) *apiv2pb.User {
return &apiv2pb.User{
Id: int32(user.ID),

View file

@ -30,7 +30,7 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store
),
)
apiv2pb.RegisterSystemServiceServer(grpcServer, NewSystemService(profile, store))
apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store))
apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store, secret))
apiv2pb.RegisterMemoServiceServer(grpcServer, NewMemoService(store))
apiv2pb.RegisterTagServiceServer(grpcServer, NewTagService(store))

View file

@ -146,6 +146,7 @@
| Name | Number | Description |
| ---- | ------ | ----------- |
| USER_SETTING_KEY_UNSPECIFIED | 0 | |
| USER_SETTING_ACCESS_TOKENS | 1 | Access tokens for the user. |

View file

@ -24,15 +24,19 @@ type UserSettingKey int32
const (
UserSettingKey_USER_SETTING_KEY_UNSPECIFIED UserSettingKey = 0
// Access tokens for the user.
UserSettingKey_USER_SETTING_ACCESS_TOKENS UserSettingKey = 1
)
// Enum value maps for UserSettingKey.
var (
UserSettingKey_name = map[int32]string{
0: "USER_SETTING_KEY_UNSPECIFIED",
1: "USER_SETTING_ACCESS_TOKENS",
}
UserSettingKey_value = map[string]int32{
"USER_SETTING_KEY_UNSPECIFIED": 0,
"USER_SETTING_ACCESS_TOKENS": 1,
}
)
@ -279,10 +283,12 @@ var file_store_user_setting_proto_rawDesc = []byte{
0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61,
0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65,
0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x32, 0x0a, 0x0e,
0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x52, 0x0a, 0x0e,
0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x4b, 0x65, 0x79, 0x12, 0x20,
0x0a, 0x1c, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47, 0x5f, 0x4b,
0x45, 0x59, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00,
0x12, 0x1e, 0x0a, 0x1a, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47,
0x5f, 0x41, 0x43, 0x43, 0x45, 0x53, 0x53, 0x5f, 0x54, 0x4f, 0x4b, 0x45, 0x4e, 0x53, 0x10, 0x01,
0x42, 0x9b, 0x01, 0x0a, 0x0f, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x65, 0x6d, 0x6f, 0x73, 0x2e, 0x73,
0x74, 0x6f, 0x72, 0x65, 0x42, 0x10, 0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e,
0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,

View file

@ -16,6 +16,9 @@ message UserSetting {
enum UserSettingKey {
USER_SETTING_KEY_UNSPECIFIED = 0;
// Access tokens for the user.
USER_SETTING_ACCESS_TOKENS = 1;
}
message AccessTokensUserSetting {

View file

@ -3,7 +3,11 @@ package store
import (
"context"
"database/sql"
"errors"
"strings"
storepb "github.com/usememos/memos/proto/gen/store"
"google.golang.org/protobuf/encoding/protojson"
)
type UserSetting struct {
@ -102,6 +106,138 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use
return userSetting, nil
}
type FindUserSettingV1 struct {
UserID *int32
Key storepb.UserSettingKey
}
func (s *Store) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else {
return nil, errors.New("invalid user setting key")
}
if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
return nil, err
}
userSettingMessage := upsert
s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage)
return userSettingMessage, nil
}
func (s *Store) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
}
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
var keyString, valueString string
if err := rows.Scan(
&userSetting.UserId,
&keyString,
&valueString,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
return nil, err
}
userSetting.Value = &storepb.UserSetting_AccessTokens{
AccessTokens: accessTokensUserSetting,
}
} else {
// Skip unknown user setting v1 key.
continue
}
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
}
return userSettingList, nil
}
func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) (*storepb.UserSetting, error) {
if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok {
return cache.(*storepb.UserSetting), nil
}
}
list, err := s.ListUserSettingsV1(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
return userSetting, nil
}
// GetUserAccessTokens returns the access tokens of the user.
func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
userSetting, err := s.GetUserSettingV1(ctx, &FindUserSettingV1{
UserID: &userID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return []*storepb.AccessTokensUserSetting_AccessToken{}, nil
}
accessTokensUserSetting := userSetting.GetAccessTokens()
return accessTokensUserSetting.AccessTokens, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM

View file

@ -13,7 +13,14 @@ export declare enum UserSettingKey {
/**
* @generated from enum value: USER_SETTING_KEY_UNSPECIFIED = 0;
*/
UNSPECIFIED = 0,
USER_SETTING_KEY_UNSPECIFIED = 0,
/**
* Access tokens for the user.
*
* @generated from enum value: USER_SETTING_ACCESS_TOKENS = 1;
*/
USER_SETTING_ACCESS_TOKENS = 1,
}
/**

View file

@ -11,7 +11,8 @@ import { proto3 } from "@bufbuild/protobuf";
export const UserSettingKey = proto3.makeEnum(
"memos.store.UserSettingKey",
[
{no: 0, name: "USER_SETTING_KEY_UNSPECIFIED", localName: "UNSPECIFIED"},
{no: 0, name: "USER_SETTING_KEY_UNSPECIFIED"},
{no: 1, name: "USER_SETTING_ACCESS_TOKENS"},
],
);