chore: add cookie builder

This commit is contained in:
Steven 2024-02-05 23:28:29 +08:00
parent 46ea16ef7e
commit 434ef44f8c
2 changed files with 54 additions and 25 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
@ -12,14 +13,12 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/usememos/memos/api/auth"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/service/metric"
"github.com/usememos/memos/store"
)
@ -31,7 +30,7 @@ func (s *APIV2Service) GetAuthStatus(ctx context.Context, _ *apiv2pb.GetAuthStat
}
if user == nil {
// Set the cookie header to expire access token.
if err := clearAccessTokenCookie(ctx); err != nil {
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header")
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
@ -61,8 +60,8 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques
expireTime := time.Now().Add(auth.AccessTokenDuration)
if request.NeverExpire {
// Zero time means never expire.
expireTime = time.Time{}
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, user, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
@ -159,13 +158,12 @@ func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTim
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
}
cookieExpires := time.Now().Add(auth.CookieExpDuration)
if expireTime.IsZero() {
// Set cookie expires to 100 years.
cookieExpires = time.Now().AddDate(100, 0, 0)
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": fmt.Sprintf("%s=%s; Path=/; Expires=%s; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName, accessToken, cookieExpires.Format(time.RFC1123)),
"Set-Cookie": cookie,
})); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
@ -222,34 +220,46 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques
}, nil
}
func (*APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
if err := clearAccessTokenCookie(ctx); err != nil {
func (s *APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return &apiv2pb.SignOutResponse{}, nil
}
func clearAccessTokenCookie(ctx context.Context) error {
func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error {
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build access token cookie")
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": fmt.Sprintf("%s=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName),
"Set-Cookie": cookie,
})); err != nil {
return errors.Wrap(err, "failed to set grpc header")
}
return nil
}
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) {
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(),
})
func (s *APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
}
workspaceGeneralSetting, err := s.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace setting")
return "", errors.Wrap(err, "failed to get workspace setting")
}
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{}
if workspaceSetting != nil {
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal workspace setting")
}
if workspaceGeneralSetting.InstanceUrl != "" && strings.HasPrefix(workspaceGeneralSetting.InstanceUrl, "https://") {
attrs = append(attrs, "SameSite=None")
attrs = append(attrs, "Secure")
} else {
attrs = append(attrs, "SameSite=Strict")
}
return workspaceGeneralSetting, nil
return strings.Join(attrs, "; "), nil
}

View file

@ -6,8 +6,11 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pkg/errors"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -88,3 +91,19 @@ func (s *APIV2Service) UpdateWorkspaceProfile(ctx context.Context, request *apiv
WorkspaceProfile: workspaceProfileMessage.WorkspaceProfile,
}, nil
}
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) {
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(),
})
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace setting")
}
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{}
if workspaceSetting != nil {
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal workspace setting")
}
}
return workspaceGeneralSetting, nil
}