2024-04-28 00:44:29 +08:00
|
|
|
package v1
|
2023-11-30 20:58:36 +08:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2023-12-20 07:42:02 +08:00
|
|
|
"fmt"
|
2024-05-20 08:53:29 +08:00
|
|
|
"log/slog"
|
2024-01-29 23:12:02 +08:00
|
|
|
"regexp"
|
2024-02-05 23:28:29 +08:00
|
|
|
"strings"
|
2024-01-29 23:12:02 +08:00
|
|
|
"time"
|
2023-11-30 20:58:36 +08:00
|
|
|
|
2023-12-20 07:42:02 +08:00
|
|
|
"github.com/pkg/errors"
|
2024-01-29 23:12:02 +08:00
|
|
|
"golang.org/x/crypto/bcrypt"
|
2023-12-20 07:42:02 +08:00
|
|
|
"google.golang.org/grpc"
|
2023-11-30 21:52:02 +08:00
|
|
|
"google.golang.org/grpc/codes"
|
2023-12-20 07:42:02 +08:00
|
|
|
"google.golang.org/grpc/metadata"
|
2023-11-30 21:52:02 +08:00
|
|
|
"google.golang.org/grpc/status"
|
2024-04-27 22:02:15 +08:00
|
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
2023-11-30 21:52:02 +08:00
|
|
|
|
2024-01-29 23:12:02 +08:00
|
|
|
"github.com/usememos/memos/internal/util"
|
|
|
|
"github.com/usememos/memos/plugin/idp"
|
|
|
|
"github.com/usememos/memos/plugin/idp/oauth2"
|
2024-04-28 00:44:29 +08:00
|
|
|
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
2024-04-13 10:50:25 +08:00
|
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
2024-01-29 23:12:02 +08:00
|
|
|
"github.com/usememos/memos/store"
|
2023-11-30 20:58:36 +08:00
|
|
|
)
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusRequest) (*v1pb.User, error) {
|
2024-05-26 11:02:23 +08:00
|
|
|
user, err := s.GetCurrentUser(ctx)
|
2023-11-30 21:52:02 +08:00
|
|
|
if err != nil {
|
|
|
|
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
|
2023-11-30 20:58:36 +08:00
|
|
|
}
|
2023-12-06 23:03:24 +08:00
|
|
|
if user == nil {
|
2023-12-20 07:42:02 +08:00
|
|
|
// Set the cookie header to expire access token.
|
2024-02-05 23:28:29 +08:00
|
|
|
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
2024-04-07 22:35:02 +08:00
|
|
|
return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
|
2023-12-20 07:42:02 +08:00
|
|
|
}
|
2023-12-06 23:03:24 +08:00
|
|
|
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
|
|
|
}
|
2024-04-27 22:02:15 +08:00
|
|
|
return convertUserFromStore(user), nil
|
2023-11-30 20:58:36 +08:00
|
|
|
}
|
2023-12-20 07:42:02 +08:00
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
|
2024-01-29 23:12:02 +08:00
|
|
|
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
|
|
|
Username: &request.Username,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", request.Username))
|
|
|
|
}
|
|
|
|
if user == nil {
|
|
|
|
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("user not found with username %s", request.Username))
|
|
|
|
} else if user.RowStatus == store.Archived {
|
|
|
|
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", request.Username))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Compare the stored hashed password, with the hashed version of the password that was received.
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
|
|
|
|
return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
|
|
|
|
}
|
|
|
|
|
2024-05-01 10:26:46 +08:00
|
|
|
expireTime := time.Now().Add(AccessTokenDuration)
|
2024-01-29 23:12:02 +08:00
|
|
|
if request.NeverExpire {
|
2024-02-05 23:28:29 +08:00
|
|
|
// Set the expire time to 100 years.
|
|
|
|
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
|
2024-01-29 23:12:02 +08:00
|
|
|
}
|
|
|
|
if err := s.doSignIn(ctx, user, expireTime); err != nil {
|
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
|
|
|
}
|
2024-04-27 22:02:15 +08:00
|
|
|
return convertUserFromStore(user), nil
|
2024-01-29 23:12:02 +08:00
|
|
|
}
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
// SignInWithSSO signs a user in through an external identity provider.
//
// Flow:
//  1. Load the identity provider identified by request.IdpId.
//  2. Exchange request.Code for the user's identity. Only OAuth2 providers
//     are supported; other provider types are rejected with
//     codes.InvalidArgument.
//  3. If the provider has an IdentifierFilter, reject identifiers that do
//     not match it (codes.PermissionDenied).
//  4. Look up the user by identifier. If none exists, create a RoleUser
//     account with a random password.
//  5. Reject archived users (codes.PermissionDenied).
//  6. Issue an access token via doSignIn, valid for AccessTokenDuration.
func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
	identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
		ID: &request.IdpId,
	})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get identity provider, err: %s", err)
	}
	if identityProvider == nil {
		return nil, status.Errorf(codes.InvalidArgument, "identity provider not found with id %d", request.IdpId)
	}

	var userInfo *idp.IdentityProviderUserInfo
	switch identityProvider.Type {
	case storepb.IdentityProvider_OAUTH2:
		oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, err: %s", err)
		}
		token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to exchange token, err: %s", err)
		}
		userInfo, err = oauth2IdentityProvider.UserInfo(token)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to get user info, err: %s", err)
		}
	default:
		// Without this branch userInfo would stay nil and be dereferenced below.
		return nil, status.Errorf(codes.InvalidArgument, "unsupported identity provider type %s", identityProvider.Type)
	}
	// Guard against a provider that yields no usable identity; an empty
	// identifier would otherwise be used as a username below.
	if userInfo == nil || userInfo.Identifier == "" {
		return nil, status.Errorf(codes.Internal, "identity provider returned no user identifier")
	}

	if identifierFilter := identityProvider.IdentifierFilter; identifierFilter != "" {
		identifierFilterRegex, err := regexp.Compile(identifierFilter)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, err: %s", err)
		}
		if !identifierFilterRegex.MatchString(userInfo.Identifier) {
			return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
		}
	}

	user, err := s.Store.GetUser(ctx, &store.FindUser{
		Username: &userInfo.Identifier,
	})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to find user by username %s, err: %s", userInfo.Identifier, err)
	}
	if user == nil {
		userCreate := &store.User{
			Username: userInfo.Identifier,
			// The new signup user should be normal user by default.
			Role:     store.RoleUser,
			Nickname: userInfo.DisplayName,
			Email:    userInfo.Email,
		}
		// SSO users never sign in with a password; store the hash of a random
		// one so the password column is never empty or guessable.
		password, err := util.RandomString(20)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to generate random password, err: %s", err)
		}
		passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to generate password hash, err: %s", err)
		}
		userCreate.PasswordHash = string(passwordHash)
		user, err = s.Store.CreateUser(ctx, userCreate)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to create user, err: %s", err)
		}
	}
	if user.RowStatus == store.Archived {
		return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", userInfo.Identifier)
	}

	if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to sign in, err: %s", err)
	}
	return convertUserFromStore(user), nil
}
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
|
2024-05-01 10:26:46 +08:00
|
|
|
accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
|
2024-01-29 23:12:02 +08:00
|
|
|
if err != nil {
|
|
|
|
return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
|
|
|
|
}
|
|
|
|
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
|
|
|
|
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
|
|
|
|
}
|
|
|
|
|
2024-02-05 23:28:29 +08:00
|
|
|
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
|
|
|
|
if err != nil {
|
|
|
|
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
|
2024-01-29 23:12:02 +08:00
|
|
|
}
|
|
|
|
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
2024-02-05 23:28:29 +08:00
|
|
|
"Set-Cookie": cookie,
|
2024-01-29 23:12:02 +08:00
|
|
|
})); err != nil {
|
|
|
|
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) (*v1pb.User, error) {
|
2024-08-28 23:46:06 +08:00
|
|
|
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
2024-07-24 23:38:51 +08:00
|
|
|
if err != nil {
|
2024-08-28 23:46:06 +08:00
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace general setting, err: %s", err))
|
2024-07-24 23:38:51 +08:00
|
|
|
}
|
2024-08-28 23:46:06 +08:00
|
|
|
if workspaceGeneralSetting.DisallowSignup {
|
2024-01-29 23:12:02 +08:00
|
|
|
return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
|
|
|
|
}
|
|
|
|
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
|
|
|
|
if err != nil {
|
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
|
|
|
|
}
|
|
|
|
|
|
|
|
create := &store.User{
|
|
|
|
Username: request.Username,
|
|
|
|
Nickname: request.Username,
|
|
|
|
PasswordHash: string(passwordHash),
|
|
|
|
}
|
2024-03-20 20:39:16 +08:00
|
|
|
if !util.UIDMatcher.MatchString(strings.ToLower(create.Username)) {
|
2024-03-15 08:37:58 +08:00
|
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", create.Username)
|
|
|
|
}
|
2024-01-31 19:55:52 +08:00
|
|
|
|
|
|
|
hostUserType := store.RoleHost
|
|
|
|
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
|
|
|
|
Role: &hostUserType,
|
|
|
|
})
|
2024-01-29 23:12:02 +08:00
|
|
|
if err != nil {
|
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to list users, err: %s", err))
|
|
|
|
}
|
2024-01-31 19:55:52 +08:00
|
|
|
if len(existedHostUsers) == 0 {
|
|
|
|
// Change the default role to host if there is no host user.
|
|
|
|
create.Role = store.RoleHost
|
2024-01-29 23:12:02 +08:00
|
|
|
} else {
|
|
|
|
create.Role = store.RoleUser
|
|
|
|
}
|
|
|
|
|
|
|
|
user, err := s.Store.CreateUser(ctx, create)
|
|
|
|
if err != nil {
|
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
|
|
|
|
}
|
|
|
|
|
2024-05-01 10:26:46 +08:00
|
|
|
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
|
2024-01-29 23:12:02 +08:00
|
|
|
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
|
|
|
}
|
2024-04-27 22:02:15 +08:00
|
|
|
return convertUserFromStore(user), nil
|
2024-01-29 23:12:02 +08:00
|
|
|
}
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
|
2024-05-20 08:53:29 +08:00
|
|
|
accessToken, ok := ctx.Value(accessTokenContextKey).(string)
|
|
|
|
// Try to delete the access token from the store.
|
|
|
|
if ok {
|
2024-07-13 11:18:29 +08:00
|
|
|
user, _ := s.GetCurrentUser(ctx)
|
|
|
|
if user != nil {
|
|
|
|
if _, err := s.DeleteUserAccessToken(ctx, &v1pb.DeleteUserAccessTokenRequest{
|
|
|
|
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
|
|
|
|
AccessToken: accessToken,
|
|
|
|
}); err != nil {
|
2024-07-27 19:24:37 +08:00
|
|
|
slog.Error("failed to delete access token", "error", err)
|
2024-07-13 11:18:29 +08:00
|
|
|
}
|
2024-05-20 08:53:29 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-02-05 23:28:29 +08:00
|
|
|
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
2024-01-29 23:12:02 +08:00
|
|
|
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
|
|
|
}
|
2024-04-27 22:02:15 +08:00
|
|
|
return &emptypb.Empty{}, nil
|
2024-01-29 23:12:02 +08:00
|
|
|
}
|
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
|
2024-02-05 23:28:29 +08:00
|
|
|
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
|
|
|
|
if err != nil {
|
|
|
|
return errors.Wrap(err, "failed to build access token cookie")
|
|
|
|
}
|
2023-12-20 07:42:02 +08:00
|
|
|
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
2024-02-05 23:28:29 +08:00
|
|
|
"Set-Cookie": cookie,
|
2023-12-20 07:42:02 +08:00
|
|
|
})); err != nil {
|
|
|
|
return errors.Wrap(err, "failed to set grpc header")
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
2024-01-29 23:12:02 +08:00
|
|
|
|
2024-04-28 00:44:29 +08:00
|
|
|
// buildAccessTokenCookie renders the Set-Cookie value that carries accessToken.
//
// The cookie is scoped to Path=/ and marked HttpOnly.
//
// Expiry: a zero expireTime produces an Expires date in 1970, so the client
// drops the cookie immediately.
//
// SameSite and Secure: these depend on the request's "origin" metadata. For
// an https:// origin the cookie gets SameSite=None and Secure; for anything
// else it gets SameSite=Strict.
//
// An error is returned when ctx carries no incoming gRPC metadata.
func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
	// Cookie dates must be in GMT: RFC 6265 §5.1.1 ignores the zone name and
	// treats the value as UTC. This layout matches net/http's TimeFormat.
	const cookieTimeLayout = "Mon, 02 Jan 2006 15:04:05 GMT"

	attrs := []string{
		fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
		"Path=/",
		"HttpOnly",
	}
	if expireTime.IsZero() {
		attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
	} else {
		// Normalize to UTC; formatting a local time would shift the real
		// expiry by the server's UTC offset.
		attrs = append(attrs, "Expires="+expireTime.UTC().Format(cookieTimeLayout))
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return "", errors.New("failed to get metadata from context")
	}
	// If several origin values are present, the last one wins.
	var origin string
	if origins := md.Get("origin"); len(origins) > 0 {
		origin = origins[len(origins)-1]
	}
	if strings.HasPrefix(origin, "https://") {
		// Cross-site requests over HTTPS need SameSite=None, which browsers
		// only accept together with Secure.
		attrs = append(attrs, "SameSite=None", "Secure")
	} else {
		attrs = append(attrs, "SameSite=Strict")
	}
	return strings.Join(attrs, "; "), nil
}
|
2024-05-26 11:02:23 +08:00
|
|
|
|
|
|
|
// GetCurrentUser resolves the user whose username is stored in ctx under
// usernameContextKey.
//
// It returns (nil, nil) when the context carries no username, and the store
// error (with a nil user) when the lookup fails. Otherwise it returns whatever
// s.Store.GetUser found, which may itself be nil.
func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) {
	username, found := ctx.Value(usernameContextKey).(string)
	if !found {
		return nil, nil
	}

	query := &store.FindUser{Username: &username}
	found_user, lookupErr := s.Store.GetUser(ctx, query)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return found_user, nil
}
|