diff --git a/api/v2/auth_service.go b/api/v2/auth_service.go index 25a43a4a..3ce83eca 100644 --- a/api/v2/auth_service.go +++ b/api/v2/auth_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "regexp" + "strings" "time" "github.com/pkg/errors" @@ -12,14 +13,12 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" "github.com/usememos/memos/api/auth" "github.com/usememos/memos/internal/util" "github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp/oauth2" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" - storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/service/metric" "github.com/usememos/memos/store" ) @@ -31,7 +30,7 @@ func (s *APIV2Service) GetAuthStatus(ctx context.Context, _ *apiv2pb.GetAuthStat } if user == nil { // Set the cookie header to expire access token. - if err := clearAccessTokenCookie(ctx); err != nil { + if err := s.clearAccessTokenCookie(ctx); err != nil { return nil, status.Errorf(codes.Internal, "failed to set grpc header") } return nil, status.Errorf(codes.Unauthenticated, "user not found") @@ -61,8 +60,8 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques expireTime := time.Now().Add(auth.AccessTokenDuration) if request.NeverExpire { - // Zero time means never expire. - expireTime = time.Time{} + // Set the expire time to 100 years. + expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) } if err := s.doSignIn(ctx, user, expireTime); err != nil { return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) @@ -159,13 +158,12 @@ func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTim return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err)) } - cookieExpires := time.Now().Add(auth.CookieExpDuration) - if expireTime.IsZero() { - // Set cookie expires to 100 years. - cookieExpires = time.Now().AddDate(100, 0, 0) + cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime) + if err != nil { + return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err)) } if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": fmt.Sprintf("%s=%s; Path=/; Expires=%s; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName, accessToken, cookieExpires.Format(time.RFC1123)), + "Set-Cookie": cookie, })); err != nil { return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) } @@ -222,34 +220,46 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques }, nil } -func (*APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) { - if err := clearAccessTokenCookie(ctx); err != nil { +func (s *APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) { + if err := s.clearAccessTokenCookie(ctx); err != nil { return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) } return &apiv2pb.SignOutResponse{}, nil } -func clearAccessTokenCookie(ctx context.Context) error { +func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error { + cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{}) + if err != nil { + return errors.Wrap(err, "failed to build access token cookie") + } if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": fmt.Sprintf("%s=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName), + "Set-Cookie": cookie, })); err != nil { return errors.Wrap(err, "failed to set grpc header") } return nil } -func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) { - workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{ - Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(), - }) +func (s *APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { + attrs := []string{ + fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken), + "Path=/", + "HttpOnly", + } + if expireTime.IsZero() { + attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT") + } else { + attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123)) + } + workspaceGeneralSetting, err := s.GetWorkspaceGeneralSetting(ctx) if err != nil { - return nil, errors.Wrap(err, "failed to get workspace setting") + return "", errors.Wrap(err, "failed to get workspace setting") } - workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{} - if workspaceSetting != nil { - if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal workspace setting") - } + if workspaceGeneralSetting.InstanceUrl != "" && strings.HasPrefix(workspaceGeneralSetting.InstanceUrl, "https://") { + attrs = append(attrs, "SameSite=None") + attrs = append(attrs, "Secure") + } else { + attrs = append(attrs, "SameSite=Strict") } - return workspaceGeneralSetting, nil + return strings.Join(attrs, "; "), nil } diff --git a/api/v2/workspace_service.go b/api/v2/workspace_service.go index 6e34d58b..65f5d195 100644 --- a/api/v2/workspace_service.go +++ b/api/v2/workspace_service.go @@ -6,8 +6,11 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "github.com/pkg/errors" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -88,3 +91,19 @@ func (s *APIV2Service) UpdateWorkspaceProfile(ctx context.Context, request *apiv WorkspaceProfile: workspaceProfileMessage.WorkspaceProfile, }, nil } + +func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) { + workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{ + Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to get workspace setting") + } + workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{} + if workspaceSetting != nil { + if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal workspace setting") + } + } + return workspaceGeneralSetting, nil +}