diff --git a/store/user_setting.go b/store/user_setting.go index fcdb0a951..9d2a8c3e5 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -27,6 +27,12 @@ type UserSessionQueryResult struct { Session *storepb.SessionsUserSetting_Session } +// RefreshTokenQueryResult contains the result of querying a refresh token. +type RefreshTokenQueryResult struct { + UserID int32 + RefreshToken *storepb.RefreshTokensUserSetting_RefreshToken +} + func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { userSettingRaw, err := convertUserSettingToRaw(upsert) if err != nil { @@ -253,6 +259,82 @@ func (s *Store) GetUserSessionByID(ctx context.Context, sessionID string) (*User return s.driver.GetUserSessionByID(ctx, sessionID) } +// GetUserRefreshTokens returns the refresh tokens of the user. +func (s *Store) GetUserRefreshTokens(ctx context.Context, userID int32) ([]*storepb.RefreshTokensUserSetting_RefreshToken, error) { + userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ + UserID: &userID, + Key: storepb.UserSetting_REFRESH_TOKENS, + }) + if err != nil { + return nil, err + } + if userSetting == nil { + return []*storepb.RefreshTokensUserSetting_RefreshToken{}, nil + } + return userSetting.GetRefreshTokens().RefreshTokens, nil +} + +// AddUserRefreshToken adds a new refresh token for the user. +func (s *Store) AddUserRefreshToken(ctx context.Context, userID int32, token *storepb.RefreshTokensUserSetting_RefreshToken) error { + existingTokens, err := s.GetUserRefreshTokens(ctx, userID) + if err != nil { + return err + } + + tokens := append(existingTokens, token) + + _, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: userID, + Key: storepb.UserSetting_REFRESH_TOKENS, + Value: &storepb.UserSetting_RefreshTokens{ + RefreshTokens: &storepb.RefreshTokensUserSetting{ + RefreshTokens: tokens, + }, + }, + }) + return err +} + +// RemoveUserRefreshToken removes a refresh token from the user. +func (s *Store) RemoveUserRefreshToken(ctx context.Context, userID int32, tokenID string) error { + existingTokens, err := s.GetUserRefreshTokens(ctx, userID) + if err != nil { + return err + } + + newTokens := make([]*storepb.RefreshTokensUserSetting_RefreshToken, 0, len(existingTokens)) + for _, token := range existingTokens { + if token.TokenId != tokenID { + newTokens = append(newTokens, token) + } + } + + _, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: userID, + Key: storepb.UserSetting_REFRESH_TOKENS, + Value: &storepb.UserSetting_RefreshTokens{ + RefreshTokens: &storepb.RefreshTokensUserSetting{ + RefreshTokens: newTokens, + }, + }, + }) + return err +} + +// GetUserRefreshTokenByID returns a specific refresh token. +func (s *Store) GetUserRefreshTokenByID(ctx context.Context, userID int32, tokenID string) (*storepb.RefreshTokensUserSetting_RefreshToken, error) { + tokens, err := s.GetUserRefreshTokens(ctx, userID) + if err != nil { + return nil, err + } + for _, token := range tokens { + if token.TokenId == tokenID { + return token, nil + } + } + return nil, nil +} + // GetUserWebhooks returns the webhooks of the user. func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) { userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ @@ -392,6 +474,12 @@ func convertUserSettingFromRaw(raw *UserSetting) (*storepb.UserSetting, error) { return nil, err } userSetting.Value = &storepb.UserSetting_General{General: generalUserSetting} + case storepb.UserSetting_REFRESH_TOKENS: + refreshTokensUserSetting := &storepb.RefreshTokensUserSetting{} + if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), refreshTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_RefreshTokens{RefreshTokens: refreshTokensUserSetting} case storepb.UserSetting_WEBHOOKS: webhooksUserSetting := &storepb.WebhooksUserSetting{} if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), webhooksUserSetting); err != nil { @@ -439,6 +527,13 @@ func convertUserSettingToRaw(userSetting *storepb.UserSetting) (*UserSetting, er return nil, err } raw.Value = string(value) + case storepb.UserSetting_REFRESH_TOKENS: + refreshTokensUserSetting := userSetting.GetRefreshTokens() + value, err := protojson.Marshal(refreshTokensUserSetting) + if err != nil { + return nil, err + } + raw.Value = string(value) case storepb.UserSetting_WEBHOOKS: webhooksUserSetting := userSetting.GetWebhooks() value, err := protojson.Marshal(webhooksUserSetting)