mirror of
https://github.com/usememos/memos.git
synced 2024-12-26 23:22:47 +08:00
chore: remove invalid access token from db (#2539)
Remove invalid access token from db
This commit is contained in:
parent
e5f660a006
commit
91296257fc
3 changed files with 47 additions and 25 deletions
|
@ -254,33 +254,14 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error {
|
||||||
// @Success 200 {boolean} true "Sign-out success"
|
// @Success 200 {boolean} true "Sign-out success"
|
||||||
// @Router /api/v1/auth/signout [POST]
|
// @Router /api/v1/auth/signout [POST]
|
||||||
func (s *APIV1Service) SignOut(c echo.Context) error {
|
func (s *APIV1Service) SignOut(c echo.Context) error {
|
||||||
ctx := c.Request().Context()
|
|
||||||
accessToken := findAccessToken(c)
|
accessToken := findAccessToken(c)
|
||||||
userID, _ := getUserIDFromAccessToken(accessToken, s.Secret)
|
userID, _ := getUserIDFromAccessToken(accessToken, s.Secret)
|
||||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
|
|
||||||
// Auto remove the current access token from the user access tokens.
|
|
||||||
if err == nil && len(userAccessTokens) != 0 {
|
|
||||||
accessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
|
|
||||||
for _, userAccessToken := range userAccessTokens {
|
|
||||||
if accessToken != userAccessToken.AccessToken {
|
|
||||||
accessTokens = append(accessTokens, userAccessToken)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{
|
err := removeAccessTokenAndCookies(c, s.Store, userID, accessToken)
|
||||||
UserId: userID,
|
if err != nil {
|
||||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to remove access token, err: %s", err)).SetInternal(err)
|
||||||
Value: &storepb.UserSetting_AccessTokens{
|
|
||||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
|
||||||
AccessTokens: accessTokens,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert user setting, err: %s", err)).SetInternal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
removeAccessTokenAndCookies(c)
|
|
||||||
return c.JSON(http.StatusOK, true)
|
return c.JSON(http.StatusOK, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -393,9 +374,15 @@ func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeAccessTokenAndCookies removes the jwt token from the cookies.
|
// removeAccessTokenAndCookies removes the jwt token from the cookies.
|
||||||
func removeAccessTokenAndCookies(c echo.Context) {
|
func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error {
|
||||||
|
err := s.RemoveUserAccessToken(c.Request().Context(), userID, token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
cookieExp := time.Now().Add(-1 * time.Hour)
|
cookieExp := time.Now().Add(-1 * time.Hour)
|
||||||
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
|
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setTokenCookie sets the token to the cookie.
|
// setTokenCookie sets the token to the cookie.
|
||||||
|
|
|
@ -8,8 +8,10 @@ import (
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/usememos/memos/api/auth"
|
"github.com/usememos/memos/api/auth"
|
||||||
|
"github.com/usememos/memos/internal/log"
|
||||||
"github.com/usememos/memos/internal/util"
|
"github.com/usememos/memos/internal/util"
|
||||||
storepb "github.com/usememos/memos/proto/gen/store"
|
storepb "github.com/usememos/memos/proto/gen/store"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
|
@ -79,7 +81,10 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
||||||
|
|
||||||
userID, err := getUserIDFromAccessToken(accessToken, secret)
|
userID, err := getUserIDFromAccessToken(accessToken, secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
removeAccessTokenAndCookies(c)
|
err = removeAccessTokenAndCookies(c, server.Store, userID, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("fail to remove AccessToken and Cookies", zap.Error(err))
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +93,10 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err)
|
||||||
}
|
}
|
||||||
if !validateAccessToken(accessToken, accessTokens) {
|
if !validateAccessToken(accessToken, accessTokens) {
|
||||||
removeAccessTokenAndCookies(c)
|
err = removeAccessTokenAndCookies(c, server.Store, userID, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("fail to remove AccessToken and Cookies", zap.Error(err))
|
||||||
|
}
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,3 +123,30 @@ func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*store
|
||||||
accessTokensUserSetting := userSetting.GetAccessTokens()
|
accessTokensUserSetting := userSetting.GetAccessTokens()
|
||||||
return accessTokensUserSetting.AccessTokens, nil
|
return accessTokensUserSetting.AccessTokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveUserAccessToken remove the access token of the user.
|
||||||
|
func (s *Store) RemoveUserAccessToken(ctx context.Context, userID int32, token string) error {
|
||||||
|
oldAccessTokens, err := s.GetUserAccessTokens(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newAccessTokens := make([]*storepb.AccessTokensUserSetting_AccessToken, 0, len(oldAccessTokens))
|
||||||
|
for _, t := range oldAccessTokens {
|
||||||
|
if token != t.AccessToken {
|
||||||
|
newAccessTokens = append(newAccessTokens, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.UpsertUserSettingV1(ctx, &storepb.UserSetting{
|
||||||
|
UserId: userID,
|
||||||
|
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||||
|
Value: &storepb.UserSetting_AccessTokens{
|
||||||
|
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||||
|
AccessTokens: newAccessTokens,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue