mirror of
https://github.com/usememos/memos.git
synced 2024-12-26 23:22:47 +08:00
feat: impl user access token api
This commit is contained in:
parent
41e26f56e9
commit
42bd9b194b
26 changed files with 507 additions and 240 deletions
|
@ -1,13 +1,13 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
// The key name used to store user id in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
UserIDContextKey = "user-id"
|
||||
// issuer is the issuer of the jwt token.
|
||||
Issuer = "memos"
|
||||
// Signing key section. For now, this is only used for signing, not for verifying since we only
|
||||
|
@ -23,3 +23,42 @@ const (
|
|||
// AccessTokenCookieName is the cookie name of access token.
|
||||
AccessTokenCookieName = "memos.access-token"
|
||||
)
|
||||
|
||||
type ClaimsMessage struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token.
|
||||
// username is the email of the user.
|
||||
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret string) (string, error) {
|
||||
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret))
|
||||
}
|
||||
|
||||
// generateToken generates a jwt token.
|
||||
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{audience},
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Subject: fmt.Sprint(userID),
|
||||
}
|
||||
if expirationTime.After(time.Now()) {
|
||||
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: registeredClaims,
|
||||
})
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
|
|
@ -5,9 +5,11 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/plugin/idp/oauth2"
|
||||
|
@ -94,12 +96,15 @@ func (s *APIV1Service) SignIn(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
|
||||
}
|
||||
|
||||
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
|
||||
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
|
||||
}
|
||||
if err := s.createAuthSignInActivity(c, user); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
|
||||
}
|
||||
cookieExp := time.Now().Add(auth.CookieExpDuration)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
|
||||
userMessage := convertUserFromStore(user)
|
||||
return c.JSON(http.StatusOK, userMessage)
|
||||
}
|
||||
|
@ -213,12 +218,15 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
|
||||
}
|
||||
|
||||
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
|
||||
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
|
||||
}
|
||||
if err := s.createAuthSignInActivity(c, user); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
|
||||
}
|
||||
cookieExp := time.Now().Add(auth.CookieExpDuration)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
|
||||
userMessage := convertUserFromStore(user)
|
||||
return c.JSON(http.StatusOK, userMessage)
|
||||
}
|
||||
|
@ -304,13 +312,15 @@ func (s *APIV1Service) SignUp(c echo.Context) error {
|
|||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
||||
}
|
||||
if err := GenerateTokensAndSetCookies(c, user, s.Secret); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
|
||||
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
|
||||
}
|
||||
if err := s.createAuthSignUpActivity(c, user); err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
|
||||
}
|
||||
|
||||
cookieExp := time.Now().Add(auth.CookieExpDuration)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
|
||||
userMessage := convertUserFromStore(user)
|
||||
return c.JSON(http.StatusOK, userMessage)
|
||||
}
|
||||
|
@ -358,3 +368,22 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveTokensAndCookies removes the jwt token from the cookies.
|
||||
func RemoveTokensAndCookies(c echo.Context) {
|
||||
cookieExp := time.Now().Add(-1 * time.Hour)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
|
||||
}
|
||||
|
||||
// setTokenCookie sets the token to the cookie.
|
||||
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
|
||||
cookie := new(http.Cookie)
|
||||
cookie.Name = name
|
||||
cookie.Value = token
|
||||
cookie.Expires = expiration
|
||||
cookie.Path = "/"
|
||||
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
|
||||
cookie.HttpOnly = true
|
||||
cookie.SameSite = http.SameSiteStrictMode
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -88,7 +87,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
|
||||
}
|
||||
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
isHostUser := false
|
||||
if ok {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
|
@ -129,7 +128,7 @@ func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error {
|
|||
// @Router /api/v1/idp [POST]
|
||||
func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -177,7 +176,7 @@ func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error {
|
|||
// @Router /api/v1/idp/{idpId} [GET]
|
||||
func (s *APIV1Service) GetIdentityProvider(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -224,7 +223,7 @@ func (s *APIV1Service) GetIdentityProvider(c echo.Context) error {
|
|||
// @Router /api/v1/idp/{idpId} [DELETE]
|
||||
func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -266,7 +265,7 @@ func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error {
|
|||
// @Router /api/v1/idp/{idpId} [PATCH]
|
||||
func (s *APIV1Service) UpdateIdentityProvider(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
@ -14,75 +13,11 @@ import (
|
|||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type claimsMessage struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token for web.
|
||||
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
|
||||
expirationTime := time.Now().Add(auth.AccessTokenDuration)
|
||||
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
|
||||
}
|
||||
|
||||
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
|
||||
func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
|
||||
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate access token")
|
||||
}
|
||||
|
||||
cookieExp := time.Now().Add(auth.CookieExpDuration)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTokensAndCookies removes the jwt token from the cookies.
|
||||
func RemoveTokensAndCookies(c echo.Context) {
|
||||
cookieExp := time.Now().Add(-1 * time.Hour)
|
||||
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
|
||||
}
|
||||
|
||||
// setTokenCookie sets the token to the cookie.
|
||||
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
|
||||
cookie := new(http.Cookie)
|
||||
cookie.Name = name
|
||||
cookie.Value = token
|
||||
cookie.Expires = expiration
|
||||
cookie.Path = "/"
|
||||
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
|
||||
cookie.HttpOnly = true
|
||||
cookie.SameSite = http.SameSiteStrictMode
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// generateToken generates a jwt token.
|
||||
func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
// Create the JWT claims, which includes the username and expiry time.
|
||||
claims := &claimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{aud},
|
||||
// In JWT, the expiry time is expressed as unix milliseconds.
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: auth.Issuer,
|
||||
Subject: fmt.Sprintf("%d", userID),
|
||||
},
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["kid"] = auth.KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
const (
|
||||
// The key name used to store user id in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
userIDContextKey = "user-id"
|
||||
)
|
||||
|
||||
func extractTokenFromHeader(c echo.Context) (string, error) {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
|
@ -149,7 +84,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
|||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
|
||||
}
|
||||
|
||||
claims := &claimsMessage{}
|
||||
claims := &auth.ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
|
@ -163,7 +98,6 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
RemoveTokensAndCookies(c)
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
|
||||
}
|
||||
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
|
||||
|
@ -188,7 +122,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
|||
}
|
||||
|
||||
// Stores userID into context.
|
||||
c.Set(auth.UserIDContextKey, userID)
|
||||
c.Set(userIDContextKey, userID)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
@ -213,7 +147,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
|
|||
}
|
||||
if user != nil {
|
||||
// Stores userID into context.
|
||||
c.Set(auth.UserIDContextKey, user.ID)
|
||||
c.Set(userIDContextKey, user.ID)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/log"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
|
@ -156,7 +155,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
currentUserID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
// Anonymous use should only fetch PUBLIC memos with specified user
|
||||
if findMemoMessage.CreatorID == nil {
|
||||
|
@ -247,7 +246,7 @@ func (s *APIV1Service) GetMemoList(c echo.Context) error {
|
|||
// - It's currently possible to create phantom resources and relations. Phantom relations will trigger backend 404's when fetching memo.
|
||||
func (s *APIV1Service) CreateMemo(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -407,7 +406,7 @@ func (s *APIV1Service) CreateMemo(c echo.Context) error {
|
|||
func (s *APIV1Service) GetAllMemos(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
findMemoMessage := &store.FindMemo{}
|
||||
_, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
_, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
findMemoMessage.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
|
@ -481,7 +480,7 @@ func (s *APIV1Service) GetMemoStats(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
|
||||
}
|
||||
|
||||
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
currentUserID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
findMemoMessage.VisibilityList = []store.Visibility{store.Public}
|
||||
} else {
|
||||
|
@ -548,7 +547,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
|
||||
}
|
||||
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if memo.Visibility == store.Private {
|
||||
if !ok || memo.CreatorID != userID {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "this memo is private only")
|
||||
|
@ -580,7 +579,7 @@ func (s *APIV1Service) GetMemo(c echo.Context) error {
|
|||
// @Router /api/v1/memo/{memoId} [DELETE]
|
||||
func (s *APIV1Service) DeleteMemo(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -633,7 +632,7 @@ func (s *APIV1Service) DeleteMemo(c echo.Context) error {
|
|||
// - Passing 0 to createdTs and updatedTs will set them to 0 in the database, which is probably unwanted.
|
||||
func (s *APIV1Service) UpdateMemo(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -47,7 +46,7 @@ func (s *APIV1Service) CreateMemoOrganizer(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
|
||||
}
|
||||
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -95,7 +94,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
|
||||
}
|
||||
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -145,7 +144,7 @@ func (s *APIV1Service) BindMemoResource(c echo.Context) error {
|
|||
// @Router /api/v1/memo/{memoId}/resource/{resourceId} [DELETE]
|
||||
func (s *APIV1Service) UnbindMemoResource(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"github.com/disintegration/imaging"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/log"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
|
@ -105,7 +104,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) {
|
|||
// @Router /api/v1/resource [GET]
|
||||
func (s *APIV1Service) GetResourceList(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -145,7 +144,7 @@ func (s *APIV1Service) GetResourceList(c echo.Context) error {
|
|||
// @Router /api/v1/resource [POST]
|
||||
func (s *APIV1Service) CreateResource(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -197,7 +196,7 @@ func (s *APIV1Service) CreateResource(c echo.Context) error {
|
|||
// @Router /api/v1/resource/blob [POST]
|
||||
func (s *APIV1Service) UploadResource(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -270,7 +269,7 @@ func (s *APIV1Service) UploadResource(c echo.Context) error {
|
|||
// @Router /api/v1/resource/{resourceId} [DELETE]
|
||||
func (s *APIV1Service) DeleteResource(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -327,7 +326,7 @@ func (s *APIV1Service) DeleteResource(c echo.Context) error {
|
|||
// @Router /api/v1/resource/{resourceId} [PATCH]
|
||||
func (s *APIV1Service) UpdateResource(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -398,7 +397,7 @@ func (s *APIV1Service) streamResource(c echo.Context) error {
|
|||
}
|
||||
|
||||
// Protected resource require a logined user
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if resourceVisibility == store.Protected && (!ok || userID <= 0) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -82,7 +81,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) {
|
|||
// @Router /api/v1/storage [GET]
|
||||
func (s *APIV1Service) GetStorageList(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -129,7 +128,7 @@ func (s *APIV1Service) GetStorageList(c echo.Context) error {
|
|||
// @Router /api/v1/storage [POST]
|
||||
func (s *APIV1Service) CreateStorage(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -190,7 +189,7 @@ func (s *APIV1Service) CreateStorage(c echo.Context) error {
|
|||
// - error message "Storage service %d is using" probably should be "Storage service %d is in use".
|
||||
func (s *APIV1Service) DeleteStorage(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -246,7 +245,7 @@ func (s *APIV1Service) DeleteStorage(c echo.Context) error {
|
|||
// @Router /api/v1/storage/{storageId} [PATCH]
|
||||
func (s *APIV1Service) UpdateStorage(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/log"
|
||||
"github.com/usememos/memos/server/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
|
@ -168,7 +167,7 @@ func (s *APIV1Service) GetSystemStatus(c echo.Context) error {
|
|||
// @Router /api/v1/system/vacuum [POST]
|
||||
func (s *APIV1Service) ExecVacuum(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
@ -95,7 +94,7 @@ func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) {
|
|||
// @Router /api/v1/system/setting [GET]
|
||||
func (s *APIV1Service) GetSystemSettingList(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -138,7 +137,7 @@ func (s *APIV1Service) GetSystemSettingList(c echo.Context) error {
|
|||
// @Router /api/v1/system/setting [POST]
|
||||
func (s *APIV1Service) CreateSystemSetting(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
@ -46,7 +45,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) {
|
|||
// @Router /api/v1/tag [GET]
|
||||
func (s *APIV1Service) GetTagList(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
|
||||
}
|
||||
|
@ -80,7 +79,7 @@ func (s *APIV1Service) GetTagList(c echo.Context) error {
|
|||
// @Router /api/v1/tag [POST]
|
||||
func (s *APIV1Service) CreateTag(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -122,7 +121,7 @@ func (s *APIV1Service) CreateTag(c echo.Context) error {
|
|||
// @Router /api/v1/tag/delete [POST]
|
||||
func (s *APIV1Service) DeleteTag(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -157,7 +156,7 @@ func (s *APIV1Service) DeleteTag(c echo.Context) error {
|
|||
// @Router /api/v1/tag/suggestion [GET]
|
||||
func (s *APIV1Service) GetTagSuggestion(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Missing user session")
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
@ -119,7 +118,7 @@ func (s *APIV1Service) GetUserList(c echo.Context) error {
|
|||
// @Router /api/v1/user [POST]
|
||||
func (s *APIV1Service) CreateUser(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
||||
}
|
||||
|
@ -184,7 +183,7 @@ func (s *APIV1Service) CreateUser(c echo.Context) error {
|
|||
// @Router /api/v1/user/me [GET]
|
||||
func (s *APIV1Service) GetCurrentUser(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
||||
}
|
||||
|
@ -287,7 +286,7 @@ func (s *APIV1Service) GetUserByID(c echo.Context) error {
|
|||
// @Router /api/v1/user/{id} [DELETE]
|
||||
func (s *APIV1Service) DeleteUser(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
currentUserID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
@ -337,7 +336,7 @@ func (s *APIV1Service) UpdateUser(c echo.Context) error {
|
|||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
|
||||
}
|
||||
|
||||
currentUserID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
currentUserID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
@ -97,7 +96,7 @@ func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) {
|
|||
// @Router /api/v1/user/setting [POST]
|
||||
func (s *APIV1Service) UpsertUserSetting(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(auth.UserIDContextKey).(int32)
|
||||
userID, ok := c.Get(userIDContextKey).(int32)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
||||
}
|
||||
|
|
|
@ -3,14 +3,11 @@ package v2
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
"github.com/usememos/memos/store"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -22,9 +19,9 @@ import (
|
|||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store user id in the context
|
||||
// The key name used to store username in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
UserIDContextKey ContextKey = iota
|
||||
usernameContextKey ContextKey = iota
|
||||
)
|
||||
|
||||
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
|
||||
|
@ -52,7 +49,7 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
|
|||
return nil, status.Errorf(codes.Unauthenticated, err.Error())
|
||||
}
|
||||
|
||||
userID, err := in.authenticate(ctx, accessTokenStr)
|
||||
username, err := in.authenticate(ctx, accessTokenStr)
|
||||
if err != nil {
|
||||
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
|
@ -60,28 +57,28 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
|
|||
return nil, err
|
||||
}
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID)
|
||||
return nil, errors.Errorf("user %q not exists", username)
|
||||
}
|
||||
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "user ID %q is not admin", userID)
|
||||
return nil, errors.Errorf("user %q is not admin", username)
|
||||
}
|
||||
|
||||
// Stores userID into context.
|
||||
childCtx := context.WithValue(ctx, UserIDContextKey, userID)
|
||||
childCtx := context.WithValue(ctx, usernameContextKey, username)
|
||||
return handler(childCtx, request)
|
||||
}
|
||||
|
||||
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (int32, error) {
|
||||
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (string, error) {
|
||||
if accessTokenStr == "" {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
return "", status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
}
|
||||
claims := &claimsMessage{}
|
||||
claims := &auth.ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
|
@ -94,34 +91,31 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr
|
|||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
}
|
||||
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
|
||||
return 0, status.Errorf(codes.Unauthenticated,
|
||||
return "", status.Errorf(codes.Unauthenticated,
|
||||
"invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
|
||||
claims.Audience,
|
||||
auth.AccessTokenAudienceName,
|
||||
)
|
||||
}
|
||||
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject)
|
||||
}
|
||||
username := claims.Name
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID)
|
||||
return "", errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID)
|
||||
return "", errors.Errorf("user %q not exists in the access token", username)
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID)
|
||||
return "", errors.Errorf("user %q is archived", username)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func getTokenFromMetadata(md metadata.MD) (string, error) {
|
||||
|
@ -154,41 +148,3 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type claimsMessage struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token for web.
|
||||
func GenerateAccessToken(username string, userID int, secret string) (string, error) {
|
||||
expirationTime := time.Now().Add(auth.AccessTokenDuration)
|
||||
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
|
||||
}
|
||||
|
||||
func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
// Create the JWT claims, which includes the username and expiry time.
|
||||
claims := &claimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{aud},
|
||||
// In JWT, the expiry time is expressed as unix milliseconds.
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: auth.Issuer,
|
||||
Subject: strconv.Itoa(userID),
|
||||
},
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["kid"] = auth.KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -26,3 +28,17 @@ func convertRowStatusToStore(rowStatus apiv2pb.RowStatus) store.RowStatus {
|
|||
return store.Normal
|
||||
}
|
||||
}
|
||||
|
||||
func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) {
|
||||
username, ok := ctx.Value(usernameContextKey).(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := s.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
|
|
@ -42,9 +42,9 @@ func (s *MemoService) ListMemos(ctx context.Context, request *apiv2pb.ListMemosR
|
|||
memoFind.CreatedTsAfter = filter.CreatedTsAfter
|
||||
}
|
||||
}
|
||||
userIDPtr := ctx.Value(UserIDContextKey)
|
||||
user, _ := getCurrentUser(ctx, s.Store)
|
||||
// If the user is not authenticated, only public memos are visible.
|
||||
if userIDPtr == nil {
|
||||
if user == nil {
|
||||
memoFind.VisibilityList = []store.Visibility{store.Public}
|
||||
}
|
||||
if request.PageSize != 0 {
|
||||
|
@ -80,12 +80,14 @@ func (s *MemoService) GetMemo(ctx context.Context, request *apiv2pb.GetMemoReque
|
|||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.Visibility != store.Public {
|
||||
userIDPtr := ctx.Value(UserIDContextKey)
|
||||
if userIDPtr == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unauthenticated")
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
userID := userIDPtr.(int32)
|
||||
if memo.Visibility == store.Private && memo.CreatorID != userID {
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,22 +31,16 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI
|
|||
defaultSystemInfo := &apiv2pb.SystemInfo{}
|
||||
|
||||
// Get the database size if the user is a host.
|
||||
userIDPtr := ctx.Value(UserIDContextKey)
|
||||
if userIDPtr != nil {
|
||||
userID := userIDPtr.(int32)
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser != nil && currentUser.Role == store.RoleHost {
|
||||
fi, err := os.Stat(s.Profile.DSN)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user != nil && user.Role == store.RoleHost {
|
||||
fi, err := os.Stat(s.Profile.DSN)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err)
|
||||
}
|
||||
defaultSystemInfo.DbSize = fi.Size()
|
||||
return nil, status.Errorf(codes.Internal, "failed to get file info: %v", err)
|
||||
}
|
||||
defaultSystemInfo.DbSize = fi.Size()
|
||||
}
|
||||
|
||||
response := &apiv2pb.GetSystemInfoResponse{
|
||||
|
@ -56,12 +50,9 @@ func (s *SystemService) GetSystemInfo(ctx context.Context, _ *apiv2pb.GetSystemI
|
|||
}
|
||||
|
||||
func (s *SystemService) UpdateSystemInfo(ctx context.Context, request *apiv2pb.UpdateSystemInfoRequest) (*apiv2pb.UpdateSystemInfoResponse, error) {
|
||||
userID := ctx.Value(UserIDContextKey).(int32)
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
|
|
|
@ -5,11 +5,16 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/common/util"
|
||||
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slices"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -18,13 +23,15 @@ import (
|
|||
type UserService struct {
|
||||
apiv2pb.UnimplementedUserServiceServer
|
||||
|
||||
Store *store.Store
|
||||
Store *store.Store
|
||||
Secret string
|
||||
}
|
||||
|
||||
// NewUserService creates a new UserService.
|
||||
func NewUserService(store *store.Store) *UserService {
|
||||
func NewUserService(store *store.Store, secret string) *UserService {
|
||||
return &UserService{
|
||||
Store: store,
|
||||
Store: store,
|
||||
Secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,13 +47,10 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
|
|||
}
|
||||
|
||||
userMessage := convertUserFromStore(user)
|
||||
userIDPtr := ctx.Value(UserIDContextKey)
|
||||
if userIDPtr != nil {
|
||||
userID := userIDPtr.(int32)
|
||||
if userID != userMessage.Id {
|
||||
// Data desensitization.
|
||||
userMessage.OpenId = ""
|
||||
}
|
||||
currentUser, _ := getCurrentUser(ctx, s.Store)
|
||||
if currentUser == nil || currentUser.ID != user.ID {
|
||||
// Data desensitization.
|
||||
userMessage.OpenId = ""
|
||||
}
|
||||
|
||||
response := &apiv2pb.GetUserResponse{
|
||||
|
@ -56,14 +60,11 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
|
|||
}
|
||||
|
||||
func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
|
||||
userID := ctx.Value(UserIDContextKey).(int32)
|
||||
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
|
||||
if currentUser.Username != request.Username && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask) == 0 {
|
||||
|
@ -72,7 +73,7 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse
|
|||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateUser{
|
||||
ID: userID,
|
||||
ID: currentUser.ID,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, path := range request.UpdateMask {
|
||||
|
@ -116,6 +117,162 @@ func (s *UserService) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUse
|
|||
return response, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil || user.Username != request.Username {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
|
||||
accessTokens := []*apiv2pb.UserAccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
claims := &auth.ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
// If the access token is invalid or expired, just ignore it.
|
||||
continue
|
||||
}
|
||||
|
||||
userAccessToken := &apiv2pb.UserAccessToken{
|
||||
AccessToken: userAccessToken.AccessToken,
|
||||
Description: userAccessToken.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
accessTokens = append(accessTokens, userAccessToken)
|
||||
}
|
||||
|
||||
// Sort by issued time in descending order.
|
||||
slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) bool {
|
||||
return i.IssuedAt.Seconds > j.IssuedAt.Seconds
|
||||
})
|
||||
response := &apiv2pb.ListUserAccessTokensResponse{
|
||||
AccessTokens: accessTokens,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
claims := &auth.ClaimsMessage{}
|
||||
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
|
||||
}
|
||||
|
||||
// Upsert the access token to user setting store.
|
||||
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.UserAccessToken.Description); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
|
||||
}
|
||||
|
||||
userAccessToken := &apiv2pb.UserAccessToken{
|
||||
AccessToken: accessToken,
|
||||
Description: request.UserAccessToken.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
response := &apiv2pb.CreateUserAccessTokenResponse{
|
||||
AccessToken: userAccessToken,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
if userAccessToken.AccessToken == request.AccessToken {
|
||||
continue
|
||||
}
|
||||
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
|
||||
}
|
||||
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: updatedUserAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user access tokens")
|
||||
}
|
||||
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
|
||||
AccessToken: accessToken,
|
||||
Description: description,
|
||||
}
|
||||
userAccessTokens = append(userAccessTokens, &userAccessToken)
|
||||
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: userAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert user setting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertUserFromStore(user *store.User) *apiv2pb.User {
|
||||
return &apiv2pb.User{
|
||||
Id: int32(user.ID),
|
||||
|
|
|
@ -30,7 +30,7 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store
|
|||
),
|
||||
)
|
||||
apiv2pb.RegisterSystemServiceServer(grpcServer, NewSystemService(profile, store))
|
||||
apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store))
|
||||
apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store, secret))
|
||||
apiv2pb.RegisterMemoServiceServer(grpcServer, NewMemoService(store))
|
||||
apiv2pb.RegisterTagServiceServer(grpcServer, NewTagService(store))
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@
|
|||
| Name | Number | Description |
|
||||
| ---- | ------ | ----------- |
|
||||
| USER_SETTING_KEY_UNSPECIFIED | 0 | |
|
||||
| USER_SETTING_ACCESS_TOKENS | 1 | Access tokens for the user. |
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -24,15 +24,19 @@ type UserSettingKey int32
|
|||
|
||||
const (
|
||||
UserSettingKey_USER_SETTING_KEY_UNSPECIFIED UserSettingKey = 0
|
||||
// Access tokens for the user.
|
||||
UserSettingKey_USER_SETTING_ACCESS_TOKENS UserSettingKey = 1
|
||||
)
|
||||
|
||||
// Enum value maps for UserSettingKey.
|
||||
var (
|
||||
UserSettingKey_name = map[int32]string{
|
||||
0: "USER_SETTING_KEY_UNSPECIFIED",
|
||||
1: "USER_SETTING_ACCESS_TOKENS",
|
||||
}
|
||||
UserSettingKey_value = map[string]int32{
|
||||
"USER_SETTING_KEY_UNSPECIFIED": 0,
|
||||
"USER_SETTING_ACCESS_TOKENS": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -279,10 +283,12 @@ var file_store_user_setting_proto_rawDesc = []byte{
|
|||
0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61,
|
||||
0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65,
|
||||
0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x32, 0x0a, 0x0e,
|
||||
0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2a, 0x52, 0x0a, 0x0e,
|
||||
0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x4b, 0x65, 0x79, 0x12, 0x20,
|
||||
0x0a, 0x1c, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47, 0x5f, 0x4b,
|
||||
0x45, 0x59, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00,
|
||||
0x12, 0x1e, 0x0a, 0x1a, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x53, 0x45, 0x54, 0x54, 0x49, 0x4e, 0x47,
|
||||
0x5f, 0x41, 0x43, 0x43, 0x45, 0x53, 0x53, 0x5f, 0x54, 0x4f, 0x4b, 0x45, 0x4e, 0x53, 0x10, 0x01,
|
||||
0x42, 0x9b, 0x01, 0x0a, 0x0f, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x65, 0x6d, 0x6f, 0x73, 0x2e, 0x73,
|
||||
0x74, 0x6f, 0x72, 0x65, 0x42, 0x10, 0x55, 0x73, 0x65, 0x72, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e,
|
||||
0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
|
||||
|
|
|
@ -16,6 +16,9 @@ message UserSetting {
|
|||
|
||||
enum UserSettingKey {
|
||||
USER_SETTING_KEY_UNSPECIFIED = 0;
|
||||
|
||||
// Access tokens for the user.
|
||||
USER_SETTING_ACCESS_TOKENS = 1;
|
||||
}
|
||||
|
||||
message AccessTokensUserSetting {
|
||||
|
|
|
@ -3,7 +3,11 @@ package store
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
type UserSetting struct {
|
||||
|
@ -102,6 +106,138 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use
|
|||
return userSetting, nil
|
||||
}
|
||||
|
||||
type FindUserSettingV1 struct {
|
||||
UserID *int32
|
||||
Key storepb.UserSettingKey
|
||||
}
|
||||
|
||||
func (s *Store) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userSettingMessage := upsert
|
||||
s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage)
|
||||
return userSettingMessage, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) ([]*storepb.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*storepb.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &storepb.UserSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserId,
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
|
||||
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Value = &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: accessTokensUserSetting,
|
||||
}
|
||||
} else {
|
||||
// Skip unknown user setting v1 key.
|
||||
continue
|
||||
}
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, userSetting := range userSettingList {
|
||||
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
||||
}
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) (*storepb.UserSetting, error) {
|
||||
if find.UserID != nil {
|
||||
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok {
|
||||
return cache.(*storepb.UserSetting), nil
|
||||
}
|
||||
}
|
||||
|
||||
list, err := s.ListUserSettingsV1(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
userSetting := list[0]
|
||||
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
||||
return userSetting, nil
|
||||
}
|
||||
|
||||
// GetUserAccessTokens returns the access tokens of the user.
|
||||
func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
|
||||
userSetting, err := s.GetUserSettingV1(ctx, &FindUserSettingV1{
|
||||
UserID: &userID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userSetting == nil {
|
||||
return []*storepb.AccessTokensUserSetting_AccessToken{}, nil
|
||||
}
|
||||
|
||||
accessTokensUserSetting := userSetting.GetAccessTokens()
|
||||
return accessTokensUserSetting.AccessTokens, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
|
|
|
@ -13,7 +13,14 @@ export declare enum UserSettingKey {
|
|||
/**
|
||||
* @generated from enum value: USER_SETTING_KEY_UNSPECIFIED = 0;
|
||||
*/
|
||||
UNSPECIFIED = 0,
|
||||
USER_SETTING_KEY_UNSPECIFIED = 0,
|
||||
|
||||
/**
|
||||
* Access tokens for the user.
|
||||
*
|
||||
* @generated from enum value: USER_SETTING_ACCESS_TOKENS = 1;
|
||||
*/
|
||||
USER_SETTING_ACCESS_TOKENS = 1,
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -11,7 +11,8 @@ import { proto3 } from "@bufbuild/protobuf";
|
|||
export const UserSettingKey = proto3.makeEnum(
|
||||
"memos.store.UserSettingKey",
|
||||
[
|
||||
{no: 0, name: "USER_SETTING_KEY_UNSPECIFIED", localName: "UNSPECIFIED"},
|
||||
{no: 0, name: "USER_SETTING_KEY_UNSPECIFIED"},
|
||||
{no: 1, name: "USER_SETTING_ACCESS_TOKENS"},
|
||||
],
|
||||
);
|
||||
|
||||
|
|
Loading…
Reference in a new issue