mirror of
https://github.com/usememos/memos.git
synced 2024-12-26 23:22:47 +08:00
chore: add cookie builder
This commit is contained in:
parent
46ea16ef7e
commit
434ef44f8c
2 changed files with 54 additions and 25 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -12,14 +13,12 @@ import (
|
|||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/usememos/memos/api/auth"
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/plugin/idp/oauth2"
|
||||
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/service/metric"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
@ -31,7 +30,7 @@ func (s *APIV2Service) GetAuthStatus(ctx context.Context, _ *apiv2pb.GetAuthStat
|
|||
}
|
||||
if user == nil {
|
||||
// Set the cookie header to expire access token.
|
||||
if err := clearAccessTokenCookie(ctx); err != nil {
|
||||
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to set grpc header")
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
|
@ -61,8 +60,8 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques
|
|||
|
||||
expireTime := time.Now().Add(auth.AccessTokenDuration)
|
||||
if request.NeverExpire {
|
||||
// Zero time means never expire.
|
||||
expireTime = time.Time{}
|
||||
// Set the expire time to 100 years.
|
||||
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
|
||||
}
|
||||
if err := s.doSignIn(ctx, user, expireTime); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
||||
|
@ -159,13 +158,12 @@ func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTim
|
|||
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
|
||||
}
|
||||
|
||||
cookieExpires := time.Now().Add(auth.CookieExpDuration)
|
||||
if expireTime.IsZero() {
|
||||
// Set cookie expires to 100 years.
|
||||
cookieExpires = time.Now().AddDate(100, 0, 0)
|
||||
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
|
||||
}
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": fmt.Sprintf("%s=%s; Path=/; Expires=%s; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName, accessToken, cookieExpires.Format(time.RFC1123)),
|
||||
"Set-Cookie": cookie,
|
||||
})); err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
||||
}
|
||||
|
@ -222,34 +220,46 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
|
||||
if err := clearAccessTokenCookie(ctx); err != nil {
|
||||
func (s *APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
|
||||
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
||||
}
|
||||
return &apiv2pb.SignOutResponse{}, nil
|
||||
}
|
||||
|
||||
func clearAccessTokenCookie(ctx context.Context) error {
|
||||
func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error {
|
||||
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to build access token cookie")
|
||||
}
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": fmt.Sprintf("%s=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName),
|
||||
"Set-Cookie": cookie,
|
||||
})); err != nil {
|
||||
return errors.Wrap(err, "failed to set grpc header")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) {
|
||||
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
|
||||
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(),
|
||||
})
|
||||
func (s *APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
|
||||
attrs := []string{
|
||||
fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken),
|
||||
"Path=/",
|
||||
"HttpOnly",
|
||||
}
|
||||
if expireTime.IsZero() {
|
||||
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
} else {
|
||||
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
|
||||
}
|
||||
workspaceGeneralSetting, err := s.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace setting")
|
||||
return "", errors.Wrap(err, "failed to get workspace setting")
|
||||
}
|
||||
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{}
|
||||
if workspaceSetting != nil {
|
||||
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal workspace setting")
|
||||
if workspaceGeneralSetting.InstanceUrl != "" && strings.HasPrefix(workspaceGeneralSetting.InstanceUrl, "https://") {
|
||||
attrs = append(attrs, "SameSite=None")
|
||||
attrs = append(attrs, "Secure")
|
||||
} else {
|
||||
attrs = append(attrs, "SameSite=Strict")
|
||||
}
|
||||
}
|
||||
return workspaceGeneralSetting, nil
|
||||
return strings.Join(attrs, "; "), nil
|
||||
}
|
||||
|
|
|
@ -6,8 +6,11 @@ import (
|
|||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
|
@ -88,3 +91,19 @@ func (s *APIV2Service) UpdateWorkspaceProfile(ctx context.Context, request *apiv
|
|||
WorkspaceProfile: workspaceProfileMessage.WorkspaceProfile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) {
|
||||
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
|
||||
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace setting")
|
||||
}
|
||||
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{}
|
||||
if workspaceSetting != nil {
|
||||
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal workspace setting")
|
||||
}
|
||||
}
|
||||
return workspaceGeneralSetting, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue