feat: implement user sessions

This commit is contained in:
Johnny 2025-06-23 09:13:51 +08:00
parent 6e4d1d9100
commit 4e3a4e36f6
14 changed files with 705 additions and 125 deletions

View file

@ -511,9 +511,6 @@ message UserSession {
// Optional. Browser name and version (e.g., "Chrome 119.0").
string browser = 5 [(google.api.field_behavior) = OPTIONAL];
// Optional. Geographic location (country code, e.g., "US").
string country = 6 [(google.api.field_behavior) = OPTIONAL];
}
}

View file

@ -1868,9 +1868,7 @@ type UserSession_ClientInfo struct {
// Optional. Operating system (e.g., "iOS 17.0", "Windows 11").
Os string `protobuf:"bytes,4,opt,name=os,proto3" json:"os,omitempty"`
// Optional. Browser name and version (e.g., "Chrome 119.0").
Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"`
// Optional. Geographic location (country code, e.g., "US").
Country string `protobuf:"bytes,6,opt,name=country,proto3" json:"country,omitempty"`
Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -1940,13 +1938,6 @@ func (x *UserSession_ClientInfo) GetBrowser() string {
return ""
}
func (x *UserSession_ClientInfo) GetCountry() string {
if x != nil {
return x.Country
}
return ""
}
var File_api_v1_user_service_proto protoreflect.FileDescriptor
const file_api_v1_user_service_proto_rawDesc = "" +
@ -2084,7 +2075,7 @@ const file_api_v1_user_service_proto_rawDesc = "" +
"\x0faccess_token_id\x18\x03 \x01(\tB\x03\xe0A\x01R\raccessTokenId\"X\n" +
"\x1cDeleteUserAccessTokenRequest\x128\n" +
"\x04name\x18\x01 \x01(\tB$\xe0A\x02\xfaA\x1e\n" +
"\x1cmemos.api.v1/UserAccessTokenR\x04name\"\xf5\x04\n" +
"\x1cmemos.api.v1/UserAccessTokenR\x04name\"\xd6\x04\n" +
"\vUserSession\x12\x17\n" +
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12\"\n" +
"\n" +
@ -2095,7 +2086,7 @@ const file_api_v1_user_service_proto_rawDesc = "" +
"expireTime\x12M\n" +
"\x12last_accessed_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\x10lastAccessedTime\x12J\n" +
"\vclient_info\x18\x06 \x01(\v2$.memos.api.v1.UserSession.ClientInfoB\x03\xe0A\x03R\n" +
"clientInfo\x1a\xc3\x01\n" +
"clientInfo\x1a\xa4\x01\n" +
"\n" +
"ClientInfo\x12\x1d\n" +
"\n" +
@ -2105,8 +2096,7 @@ const file_api_v1_user_service_proto_rawDesc = "" +
"\vdevice_type\x18\x03 \x01(\tB\x03\xe0A\x01R\n" +
"deviceType\x12\x13\n" +
"\x02os\x18\x04 \x01(\tB\x03\xe0A\x01R\x02os\x12\x1d\n" +
"\abrowser\x18\x05 \x01(\tB\x03\xe0A\x01R\abrowser\x12\x1d\n" +
"\acountry\x18\x06 \x01(\tB\x03\xe0A\x01R\acountry:D\xeaAA\n" +
"\abrowser\x18\x05 \x01(\tB\x03\xe0A\x01R\abrowser:D\xeaAA\n" +
"\x18memos.api.v1/UserSession\x12\x1fusers/{user}/sessions/{session}\x1a\x04name\"L\n" +
"\x17ListUserSessionsRequest\x121\n" +
"\x06parent\x18\x01 \x01(\tB\x19\xe0A\x02\xfaA\x13\n" +

View file

@ -4340,9 +4340,6 @@ definitions:
browser:
type: string
description: Optional. Browser name and version (e.g., "Chrome 119.0").
country:
type: string
description: Optional. Geographic location (country code, e.g., "US").
v1UserStats:
type: object
properties:

View file

@ -590,9 +590,7 @@ type SessionsUserSetting_ClientInfo struct {
// Optional. Operating system (e.g., "iOS 17.0", "Windows 11").
Os string `protobuf:"bytes,4,opt,name=os,proto3" json:"os,omitempty"`
// Optional. Browser name and version (e.g., "Chrome 119.0").
Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"`
// Optional. Geographic location (country code, e.g., "US").
Country string `protobuf:"bytes,6,opt,name=country,proto3" json:"country,omitempty"`
Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@ -662,13 +660,6 @@ func (x *SessionsUserSetting_ClientInfo) GetBrowser() string {
return ""
}
func (x *SessionsUserSetting_ClientInfo) GetCountry() string {
if x != nil {
return x.Country
}
return ""
}
var File_store_user_setting_proto protoreflect.FileDescriptor
const file_store_user_setting_proto_rawDesc = "" +
@ -696,7 +687,7 @@ const file_store_user_setting_proto_rawDesc = "" +
"\bShortcut\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" +
"\x05title\x18\x02 \x01(\tR\x05title\x12\x16\n" +
"\x06filter\x18\x03 \x01(\tR\x06filter\"\xca\x04\n" +
"\x06filter\x18\x03 \x01(\tR\x06filter\"\xb0\x04\n" +
"\x13SessionsUserSetting\x12D\n" +
"\bsessions\x18\x01 \x03(\v2(.memos.store.SessionsUserSetting.SessionR\bsessions\x1a\xba\x02\n" +
"\aSession\x12\x1d\n" +
@ -708,7 +699,7 @@ const file_store_user_setting_proto_rawDesc = "" +
"expireTime\x12H\n" +
"\x12last_accessed_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\x10lastAccessedTime\x12L\n" +
"\vclient_info\x18\x05 \x01(\v2+.memos.store.SessionsUserSetting.ClientInfoR\n" +
"clientInfo\x1a\xaf\x01\n" +
"clientInfo\x1a\x95\x01\n" +
"\n" +
"ClientInfo\x12\x1d\n" +
"\n" +
@ -718,8 +709,7 @@ const file_store_user_setting_proto_rawDesc = "" +
"\vdevice_type\x18\x03 \x01(\tR\n" +
"deviceType\x12\x0e\n" +
"\x02os\x18\x04 \x01(\tR\x02os\x12\x18\n" +
"\abrowser\x18\x05 \x01(\tR\abrowser\x12\x18\n" +
"\acountry\x18\x06 \x01(\tR\acountry*\x93\x01\n" +
"\abrowser\x18\x05 \x01(\tR\abrowser*\x93\x01\n" +
"\x0eUserSettingKey\x12 \n" +
"\x1cUSER_SETTING_KEY_UNSPECIFIED\x10\x00\x12\x11\n" +
"\rACCESS_TOKENS\x10\x01\x12\n" +

View file

@ -80,8 +80,6 @@ message SessionsUserSetting {
string os = 4;
// Optional. Browser name and version (e.g., "Chrome 119.0").
string browser = 5;
// Optional. Geographic location (country code, e.g., "US").
string country = 6;
}
repeated Session sessions = 1;

View file

@ -52,22 +52,38 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
// Try to get access token from either Authorization header or cookie
accessToken, err := getTokenFromMetadata(md)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get access token: %v", err)
}
// Authenticate using access token (which also validates sessions when it's from cookie)
user, err := in.authenticateByAccessToken(ctx, accessToken)
if err != nil {
// Check if this method is in the allowlist first
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
// Try to authenticate via session ID (from cookie) first
if sessionCookieValue, err := getSessionIDFromMetadata(md); err == nil && sessionCookieValue != "" {
user, err := in.authenticateBySession(ctx, sessionCookieValue)
if err == nil && user != nil {
// Extract just the sessionID part for context storage
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookieValue)
if parseErr != nil {
return nil, status.Errorf(codes.Internal, "failed to parse session cookie: %v", parseErr)
}
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, sessionID, "")
}
return nil, err
}
// Try to authenticate via JWT access token (from Authorization header)
if accessToken, err := getAccessTokenFromMetadata(md); err == nil && accessToken != "" {
user, err := in.authenticateByJWT(ctx, accessToken)
if err == nil && user != nil {
return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, "", accessToken)
}
}
// If no valid authentication found, check if this method is in the allowlist (public endpoints)
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
// If authentication is required but not found, reject the request
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
// handleAuthenticatedRequest processes an authenticated request with the given user and auth info.
func (in *GRPCAuthInterceptor) handleAuthenticatedRequest(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler, user *store.User, sessionID, accessToken string) (any, error) {
// Check user status
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", user.Username)
@ -79,22 +95,21 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
// Set context values
ctx = context.WithValue(ctx, userIDContextKey, user.ID)
// Determine if this came from cookie (session) or header (API token)
if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil {
// Came from cookie, treat as session
ctx = context.WithValue(ctx, sessionIDContextKey, accessToken)
if sessionID != "" {
// Session-based authentication
ctx = context.WithValue(ctx, sessionIDContextKey, sessionID)
// Update session last accessed time
_ = in.updateSessionLastAccessed(ctx, user.ID, accessToken)
} else {
// Came from Authorization header, treat as API token
_ = in.updateSessionLastAccessed(ctx, user.ID, sessionID)
} else if accessToken != "" {
// JWT access token-based authentication
ctx = context.WithValue(ctx, accessTokenContextKey, accessToken)
}
return handler(ctx, request)
}
// authenticateByAccessToken authenticates a user using access token from Authorization header or cookie.
func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (*store.User, error) {
// authenticateByJWT authenticates a user using JWT access token from Authorization header.
func (in *GRPCAuthInterceptor) authenticateByJWT(ctx context.Context, accessToken string) (*store.User, error) {
if accessToken == "" {
return nil, status.Errorf(codes.Unauthenticated, "access token not found")
}
@ -114,7 +129,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
// Get user from JWT claims
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return nil, errors.Wrap(err, "malformed ID in the token")
@ -132,6 +147,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
return nil, errors.Errorf("user %q is archived", userID)
}
// Validate that this access token exists in the user's access tokens
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, errors.Wrapf(err, "failed to get user access tokens")
@ -140,10 +156,43 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
}
// For tokens that might be used as session IDs (from cookies), also validate session existence
// This is a best-effort check - if sessions can't be retrieved or token isn't a session, that's ok
if sessions, err := in.Store.GetUserSessions(ctx, user.ID); err == nil {
validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens
return user, nil
}
// authenticateBySession authenticates a user using session ID from cookie.
func (in *GRPCAuthInterceptor) authenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) {
if sessionCookieValue == "" {
return nil, status.Errorf(codes.Unauthenticated, "session cookie value not found")
}
// Parse the cookie value to extract userID and sessionID
userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid session cookie format: %v", err)
}
// Get the user directly using the userID from the cookie
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.Unauthenticated, "user is archived")
}
// Get user sessions and validate the sessionID
sessions, err := in.Store.GetUserSessions(ctx, userID)
if err != nil {
return nil, errors.Wrap(err, "failed to get user sessions")
}
if !validateUserSession(sessionID, sessions) {
return nil, status.Errorf(codes.Unauthenticated, "invalid or expired session")
}
return user, nil
@ -168,6 +217,24 @@ func validateUserSession(sessionID string, userSessions []*storepb.SessionsUserS
return false
}
// getSessionIDFromMetadata extracts session cookie value from cookie.
func getSessionIDFromMetadata(md metadata.MD) (string, error) {
// Check the cookie header for session cookie value
var sessionCookieValue string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(SessionCookieName); v != nil {
sessionCookieValue = v.Value
}
}
if sessionCookieValue == "" {
return "", errors.New("session cookie not found")
}
return sessionCookieValue, nil
}
// getAccessTokenFromMetadata extracts access token from Authorization header.
func getAccessTokenFromMetadata(md metadata.MD) (string, error) {
// Check the HTTP request Authorization header.
@ -182,29 +249,6 @@ func getAccessTokenFromMetadata(md metadata.MD) (string, error) {
return authHeaderParts[1], nil
}
func getTokenFromMetadata(md metadata.MD) (string, error) {
// Check the HTTP request header first.
authorizationHeaders := md.Get("Authorization")
if len(authorizationHeaders) > 0 {
authHeaderParts := strings.Fields(authorizationHeaders[0])
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// Check the cookie header.
var accessToken string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(AccessTokenCookieName); v != nil {
accessToken = v.Value
}
}
return accessToken, nil
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {

View file

@ -2,9 +2,12 @@ package v1
import (
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/usememos/memos/internal/util"
)
const (
@ -20,8 +23,8 @@ const (
// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user
// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt.
CookieExpDuration = AccessTokenDuration - 1*time.Minute
// AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "memos.access-token"
// SessionCookieName is the cookie name of user session ID.
SessionCookieName = "user_session"
)
type ClaimsMessage struct {
@ -61,3 +64,28 @@ func generateToken(username string, userID int32, audience string, expirationTim
return tokenString, nil
}
// GenerateSessionID generates a unique session ID using UUIDv4.
func GenerateSessionID() (string, error) {
return util.GenUUID(), nil
}
// BuildSessionCookieValue builds the session cookie value in format {userID}-{sessionID}.
func BuildSessionCookieValue(userID int32, sessionID string) string {
return fmt.Sprintf("%d-%s", userID, sessionID)
}
// ParseSessionCookieValue parses the session cookie value to extract userID and sessionID.
func ParseSessionCookieValue(cookieValue string) (int32, string, error) {
parts := strings.SplitN(cookieValue, "-", 2)
if len(parts) != 2 {
return 0, "", fmt.Errorf("invalid session cookie format")
}
userID, err := util.ConvertStringToInt32(parts[0])
if err != nil {
return 0, "", fmt.Errorf("invalid user ID in session cookie: %v", err)
}
return userID, parts[1], nil
}

View file

@ -36,9 +36,9 @@ func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrent
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Set the cookie header to expire access token.
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
// Clear auth cookies
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err)
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
@ -178,6 +178,7 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe
}
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
// Generate JWT access token for API use
accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
if err != nil {
return status.Errorf(codes.Internal, "failed to generate access token, error: %v", err)
@ -186,19 +187,27 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim
return status.Errorf(codes.Internal, "failed to upsert access token to store, error: %v", err)
}
// Generate unique session ID for web use
sessionID, err := GenerateSessionID()
if err != nil {
return status.Errorf(codes.Internal, "failed to generate session ID, error: %v", err)
}
// Track session in user settings
if err := s.trackUserSession(ctx, user.ID, accessToken, expireTime); err != nil {
if err := s.trackUserSession(ctx, user.ID, sessionID, expireTime); err != nil {
// Log the error but don't fail the login if session tracking fails
// This ensures backward compatibility
slog.Error("failed to track user session", "error", err)
}
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
// Set session cookie for web use (format: userID-sessionID)
sessionCookieValue := BuildSessionCookieValue(user.ID, sessionID)
sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime)
if err != nil {
return status.Errorf(codes.Internal, "failed to build access token cookie, error: %v", err)
return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err)
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
"Set-Cookie": sessionCookie,
})); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
@ -281,28 +290,31 @@ func (s *APIV1Service) DeleteSession(ctx context.Context, _ *v1pb.DeleteSessionR
}
}
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
if err := s.clearAuthCookies(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
// Clear session cookie
sessionCookie, err := s.buildSessionCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build access token cookie")
return errors.Wrap(err, "failed to build session cookie")
}
// Set both cookies in the response
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
"Set-Cookie": sessionCookie,
})); err != nil {
return errors.Wrap(err, "failed to set grpc header")
}
return nil
}
func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
fmt.Sprintf("%s=%s", SessionCookieName, sessionCookieValue),
"Path=/",
"HttpOnly",
}
@ -364,23 +376,189 @@ func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessi
}
// Helper function to extract client information from the gRPC context.
func (*APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo {
// extractClientInfo extracts comprehensive client information from the request context.
// This includes user agent parsing to determine device type, operating system, browser,
// and IP address extraction. This information is used to provide detailed session
// tracking and management capabilities in the web UI.
//
// Fields populated:
// - UserAgent: Raw user agent string
// - IpAddress: Client IP (from X-Forwarded-For or X-Real-IP headers)
// - DeviceType: "mobile", "tablet", or "desktop"
// - Os: Operating system name and version (e.g., "iOS 17.1", "Windows 10/11")
// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0")
// - Country: Geographic location (TODO: implement with GeoIP service)
func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
// Extract user agent from metadata if available
if md, ok := metadata.FromIncomingContext(ctx); ok {
if userAgents := md.Get("user-agent"); len(userAgents) > 0 {
clientInfo.UserAgent = userAgents[0]
userAgent := userAgents[0]
clientInfo.UserAgent = userAgent
// Parse user agent to extract device type, OS, browser info
s.parseUserAgent(userAgent, clientInfo)
}
if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 {
clientInfo.IpAddress = forwardedFor[0]
ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple
ipAddress = strings.TrimSpace(ipAddress)
clientInfo.IpAddress = ipAddress
} else if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
clientInfo.IpAddress = realIP[0]
}
}
// TODO: Parse user agent to extract device type, OS, browser info
// This could be done using a user agent parsing library
return clientInfo
}
// parseUserAgent extracts device type, OS, and browser information from user agent string
func (s *APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.SessionsUserSetting_ClientInfo) {
if userAgent == "" {
return
}
userAgent = strings.ToLower(userAgent)
// Detect device type
if strings.Contains(userAgent, "ipad") {
clientInfo.DeviceType = "tablet"
} else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") ||
strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") ||
strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") {
clientInfo.DeviceType = "mobile"
} else if strings.Contains(userAgent, "tablet") {
clientInfo.DeviceType = "tablet"
} else {
clientInfo.DeviceType = "desktop"
}
// Detect operating system
if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") {
// Extract iOS version
if idx := strings.Index(userAgent, "cpu os "); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1)
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else if idx := strings.Index(userAgent, "iphone os "); idx != -1 {
versionStart := idx + 10
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd != -1 {
version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1)
clientInfo.Os = "iOS " + version
} else {
clientInfo.Os = "iOS"
}
} else {
clientInfo.Os = "iOS"
}
} else if strings.Contains(userAgent, "android") {
// Extract Android version
if idx := strings.Index(userAgent, "android "); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Os = "Android " + version
} else {
clientInfo.Os = "Android"
}
} else {
clientInfo.Os = "Android"
}
} else if strings.Contains(userAgent, "windows nt 10.0") {
clientInfo.Os = "Windows 10/11"
} else if strings.Contains(userAgent, "windows nt 6.3") {
clientInfo.Os = "Windows 8.1"
} else if strings.Contains(userAgent, "windows nt 6.1") {
clientInfo.Os = "Windows 7"
} else if strings.Contains(userAgent, "windows") {
clientInfo.Os = "Windows"
} else if strings.Contains(userAgent, "mac os x") {
// Extract macOS version
if idx := strings.Index(userAgent, "mac os x "); idx != -1 {
versionStart := idx + 9
versionEnd := strings.Index(userAgent[versionStart:], ";")
if versionEnd == -1 {
versionEnd = strings.Index(userAgent[versionStart:], ")")
}
if versionEnd != -1 {
version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1)
clientInfo.Os = "macOS " + version
} else {
clientInfo.Os = "macOS"
}
} else {
clientInfo.Os = "macOS"
}
} else if strings.Contains(userAgent, "linux") {
clientInfo.Os = "Linux"
} else if strings.Contains(userAgent, "cros") {
clientInfo.Os = "Chrome OS"
}
// Detect browser
if strings.Contains(userAgent, "edg/") {
// Extract Edge version
if idx := strings.Index(userAgent, "edg/"); idx != -1 {
versionStart := idx + 4
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Edge " + version
} else {
clientInfo.Browser = "Edge"
}
} else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") {
// Extract Chrome version
if idx := strings.Index(userAgent, "chrome/"); idx != -1 {
versionStart := idx + 7
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Chrome " + version
} else {
clientInfo.Browser = "Chrome"
}
} else if strings.Contains(userAgent, "firefox/") {
// Extract Firefox version
if idx := strings.Index(userAgent, "firefox/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Firefox " + version
} else {
clientInfo.Browser = "Firefox"
}
} else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") {
// Extract Safari version
if idx := strings.Index(userAgent, "version/"); idx != -1 {
versionStart := idx + 8
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd == -1 {
versionEnd = len(userAgent) - versionStart
}
version := userAgent[versionStart : versionStart+versionEnd]
clientInfo.Browser = "Safari " + version
} else {
clientInfo.Browser = "Safari"
}
} else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") {
clientInfo.Browser = "Opera"
}
}

View file

@ -0,0 +1,179 @@
package v1
import (
"context"
"testing"
"google.golang.org/grpc/metadata"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestParseUserAgent(t *testing.T) {
service := &APIV1Service{}
tests := []struct {
name string
userAgent string
expectedDevice string
expectedOS string
expectedBrowser string
}{
{
name: "Chrome on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on macOS",
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15",
expectedDevice: "desktop",
expectedOS: "macOS 10.15.7",
expectedBrowser: "Safari 17.0",
},
{
name: "Chrome on Android Mobile",
userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36",
expectedDevice: "mobile",
expectedOS: "Android 13",
expectedBrowser: "Chrome 119.0.0.0",
},
{
name: "Safari on iPhone",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "mobile",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
{
name: "Firefox on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Firefox 119.0",
},
{
name: "Edge on Windows",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0",
expectedDevice: "desktop",
expectedOS: "Windows 10/11",
expectedBrowser: "Edge 119.0.0.0",
},
{
name: "iPad Safari",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1",
expectedDevice: "tablet",
expectedOS: "iOS 17.0",
expectedBrowser: "Safari 17.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
service.parseUserAgent(tt.userAgent, clientInfo)
if clientInfo.DeviceType != tt.expectedDevice {
t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType)
}
if clientInfo.Os != tt.expectedOS {
t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os)
}
if clientInfo.Browser != tt.expectedBrowser {
t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser)
}
})
}
}
func TestExtractClientInfo(t *testing.T) {
service := &APIV1Service{}
// Test with metadata containing user agent and IP
md := metadata.New(map[string]string{
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
"x-forwarded-for": "203.0.113.1, 198.51.100.1",
"x-real-ip": "203.0.113.1",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
clientInfo := service.extractClientInfo(ctx)
if clientInfo.UserAgent == "" {
t.Error("Expected user agent to be set")
}
if clientInfo.IpAddress != "203.0.113.1" {
t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress)
}
if clientInfo.DeviceType != "desktop" {
t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType)
}
if clientInfo.Os != "Windows 10/11" {
t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os)
}
if clientInfo.Browser != "Chrome 119.0.0.0" {
t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser)
}
}
// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents
func TestClientInfoExamples(t *testing.T) {
service := &APIV1Service{}
examples := []struct {
description string
userAgent string
}{
{
description: "Modern Chrome on Windows 11",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
},
{
description: "Safari on iPhone 15 Pro",
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
{
description: "Chrome on Samsung Galaxy",
userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36",
},
{
description: "Firefox on Ubuntu",
userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0",
},
{
description: "Edge on Windows 10",
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
},
{
description: "Safari on iPad Air",
userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1",
},
}
for _, example := range examples {
t.Run(example.description, func(t *testing.T) {
clientInfo := &storepb.SessionsUserSetting_ClientInfo{}
service.parseUserAgent(example.userAgent, clientInfo)
t.Logf("User Agent: %s", example.userAgent)
t.Logf("Device Type: %s", clientInfo.DeviceType)
t.Logf("Operating System: %s", clientInfo.Os)
t.Logf("Browser: %s", clientInfo.Browser)
t.Logf("---")
// Ensure all fields are populated
if clientInfo.DeviceType == "" {
t.Error("Device type should not be empty")
}
if clientInfo.Os == "" {
t.Error("OS should not be empty")
}
if clientInfo.Browser == "" {
t.Error("Browser should not be empty")
}
})
}
}

View file

@ -627,7 +627,6 @@ func (s *APIV1Service) ListUserSessions(ctx context.Context, request *v1pb.ListU
DeviceType: userSession.ClientInfo.DeviceType,
Os: userSession.ClientInfo.Os,
Browser: userSession.ClientInfo.Browser,
Country: userSession.ClientInfo.Country,
}
}

View file

@ -7,6 +7,7 @@ import showUpdateAccountDialog from "../UpdateAccountDialog";
import UserAvatar from "../UserAvatar";
import { Popover, PopoverContent, PopoverTrigger } from "../ui/Popover";
import AccessTokenSection from "./AccessTokenSection";
import UserSessionsSection from "./UserSessionsSection";
const MyAccountSection = () => {
const t = useTranslate();
@ -48,6 +49,7 @@ const MyAccountSection = () => {
</div>
<AccessTokenSection />
<UserSessionsSection />
</div>
);
};

View file

@ -0,0 +1,177 @@
import { Button } from "@usememos/mui";
import { ClockIcon, MonitorIcon, SmartphoneIcon, TabletIcon, TrashIcon, WifiIcon } from "lucide-react";
import { useEffect, useState } from "react";
import { toast } from "react-hot-toast";
import { userServiceClient } from "@/grpcweb";
import useCurrentUser from "@/hooks/useCurrentUser";
import { UserSession } from "@/types/proto/api/v1/user_service";
import { useTranslate } from "@/utils/i18n";
import LearnMore from "../LearnMore";
const listUserSessions = async (parent: string) => {
const { sessions } = await userServiceClient.listUserSessions({ parent });
return sessions.sort((a, b) => (b.lastAccessedTime?.getTime() ?? 0) - (a.lastAccessedTime?.getTime() ?? 0));
};
const UserSessionsSection = () => {
const t = useTranslate();
const currentUser = useCurrentUser();
const [userSessions, setUserSessions] = useState<UserSession[]>([]);
useEffect(() => {
listUserSessions(currentUser.name).then((sessions) => {
setUserSessions(sessions);
});
}, []);
const handleRevokeSession = async (userSession: UserSession) => {
const formattedSessionId = getFormattedSessionId(userSession.sessionId);
const confirmed = window.confirm(t("setting.user-sessions-section.session-revocation", { sessionId: formattedSessionId }));
if (confirmed) {
await userServiceClient.revokeUserSession({ name: userSession.name });
setUserSessions(userSessions.filter((session) => session.sessionId !== userSession.sessionId));
toast.success(t("setting.user-sessions-section.session-revoked"));
}
};
const getFormattedSessionId = (sessionId: string) => {
return `${sessionId.slice(0, 8)}...${sessionId.slice(-8)}`;
};
const getDeviceIcon = (deviceType: string) => {
switch (deviceType?.toLowerCase()) {
case "mobile":
return <SmartphoneIcon className="w-4 h-4 text-gray-500" />;
case "tablet":
return <TabletIcon className="w-4 h-4 text-gray-500" />;
case "desktop":
default:
return <MonitorIcon className="w-4 h-4 text-gray-500" />;
}
};
const formatLocation = (clientInfo: UserSession["clientInfo"]) => {
if (!clientInfo) return "Unknown";
const parts = [];
if (clientInfo.ipAddress) parts.push(clientInfo.ipAddress);
return parts.length > 0 ? parts.join(" • ") : "Unknown";
};
const formatDeviceInfo = (clientInfo: UserSession["clientInfo"]) => {
if (!clientInfo) return "Unknown Device";
const parts = [];
if (clientInfo.os) parts.push(clientInfo.os);
if (clientInfo.browser) parts.push(clientInfo.browser);
return parts.length > 0 ? parts.join(" • ") : "Unknown Device";
};
const isCurrentSession = (session: UserSession) => {
// A simple heuristic: the most recently accessed session is likely the current one
if (userSessions.length === 0) return false;
const mostRecent = userSessions[0];
return session.sessionId === mostRecent.sessionId;
};
return (
<div className="mt-6 w-full flex flex-col justify-start items-start space-y-4">
<div className="w-full">
<div className="sm:flex sm:items-center sm:justify-between">
<div className="sm:flex-auto space-y-1">
<p className="flex flex-row justify-start items-center font-medium text-gray-700 dark:text-gray-400">
{t("setting.user-sessions-section.title")}
<LearnMore className="ml-2" url="https://usememos.com/docs/security/sessions" />
</p>
<p className="text-sm text-gray-700 dark:text-gray-500">{t("setting.user-sessions-section.description")}</p>
</div>
</div>
<div className="w-full mt-2 flow-root">
<div className="overflow-x-auto">
<div className="inline-block min-w-full border border-zinc-200 rounded-lg align-middle dark:border-zinc-600">
<table className="min-w-full divide-y divide-gray-300 dark:divide-zinc-600">
<thead>
<tr>
<th scope="col" className="px-3 py-2 text-left text-sm font-semibold text-gray-900 dark:text-gray-400">
{t("setting.user-sessions-section.device")}
</th>
<th scope="col" className="py-2 pl-4 pr-3 text-left text-sm font-semibold text-gray-900 dark:text-gray-400">
{t("setting.user-sessions-section.location")}
</th>
<th scope="col" className="px-3 py-2 text-left text-sm font-semibold text-gray-900 dark:text-gray-400">
{t("setting.user-sessions-section.last-active")}
</th>
<th scope="col" className="px-3 py-2 text-left text-sm font-semibold text-gray-900 dark:text-gray-400">
{t("setting.user-sessions-section.expires")}
</th>
<th scope="col" className="relative py-3.5 pl-3 pr-4">
<span className="sr-only">{t("common.delete")}</span>
</th>
</tr>
</thead>
<tbody className="divide-y divide-gray-200 dark:divide-zinc-700">
{userSessions.map((userSession) => (
<tr key={userSession.sessionId}>
<td className="whitespace-nowrap px-3 py-2 text-sm text-gray-900 dark:text-gray-400">
<div className="flex items-center space-x-3">
{getDeviceIcon(userSession.clientInfo?.deviceType || "")}
<div className="flex flex-col">
<span className="font-medium">
{formatDeviceInfo(userSession.clientInfo)}
{isCurrentSession(userSession) && (
<span className="ml-2 inline-flex items-center px-2 py-1 rounded-full text-xs font-medium bg-green-100 text-green-800 dark:bg-green-800 dark:text-green-100">
<WifiIcon className="w-3 h-3 mr-1" />
{t("setting.user-sessions-section.current")}
</span>
)}
</span>
<span className="text-xs text-gray-500 font-mono">{getFormattedSessionId(userSession.sessionId)}</span>
</div>
</div>
</td>
<td className="whitespace-nowrap py-2 pl-4 pr-3 text-sm text-gray-900 dark:text-gray-400">
{formatLocation(userSession.clientInfo)}
</td>
<td className="whitespace-nowrap px-3 py-2 text-sm text-gray-500 dark:text-gray-400">
<div className="flex items-center space-x-1">
<ClockIcon className="w-4 h-4" />
<span>{userSession.lastAccessedTime?.toLocaleString()}</span>
</div>
</td>
<td className="whitespace-nowrap px-3 py-2 text-sm text-gray-500 dark:text-gray-400">
{userSession.expireTime?.toLocaleString() ?? t("setting.user-sessions-section.never")}
</td>
<td className="relative whitespace-nowrap py-2 pl-3 pr-4 text-right text-sm">
<Button
variant="plain"
disabled={isCurrentSession(userSession)}
onClick={() => {
handleRevokeSession(userSession);
}}
title={
isCurrentSession(userSession)
? t("setting.user-sessions-section.cannot-revoke-current")
: t("setting.user-sessions-section.revoke-session")
}
>
<TrashIcon className={`w-4 h-auto ${isCurrentSession(userSession) ? "text-gray-400" : "text-red-600"}`} />
</Button>
</td>
</tr>
))}
</tbody>
</table>
{userSessions.length === 0 && (
<div className="text-center py-8 text-gray-500 dark:text-gray-400">{t("setting.user-sessions-section.no-sessions")}</div>
)}
</div>
</div>
</div>
</div>
</div>
);
};
export default UserSessionsSection;

View file

@ -251,6 +251,21 @@
"title": "Access Tokens",
"token": "Token"
},
"user-sessions-section": {
"title": "Active Sessions",
"description": "A list of all active sessions for your account. You can revoke any session except the current one.",
"device": "Device",
"location": "Location",
"last-active": "Last Active",
"expires": "Expires",
"current": "Current",
"never": "Never",
"session-revocation": "Are you sure to revoke session {{sessionId}}? You will need to sign in again on that device.",
"session-revoked": "Session revoked successfully",
"revoke-session": "Revoke session",
"cannot-revoke-current": "Cannot revoke current session",
"no-sessions": "No active sessions found"
},
"account-section": {
"change-password": "Change password",
"email-note": "Optional",

View file

@ -394,8 +394,6 @@ export interface UserSession_ClientInfo {
os: string;
/** Optional. Browser name and version (e.g., "Chrome 119.0"). */
browser: string;
/** Optional. Geographic location (country code, e.g., "US"). */
country: string;
}
export interface ListUserSessionsRequest {
@ -2222,7 +2220,7 @@ export const UserSession: MessageFns<UserSession> = {
};
function createBaseUserSession_ClientInfo(): UserSession_ClientInfo {
return { userAgent: "", ipAddress: "", deviceType: "", os: "", browser: "", country: "" };
return { userAgent: "", ipAddress: "", deviceType: "", os: "", browser: "" };
}
export const UserSession_ClientInfo: MessageFns<UserSession_ClientInfo> = {
@ -2242,9 +2240,6 @@ export const UserSession_ClientInfo: MessageFns<UserSession_ClientInfo> = {
if (message.browser !== "") {
writer.uint32(42).string(message.browser);
}
if (message.country !== "") {
writer.uint32(50).string(message.country);
}
return writer;
},
@ -2295,14 +2290,6 @@ export const UserSession_ClientInfo: MessageFns<UserSession_ClientInfo> = {
message.browser = reader.string();
continue;
}
case 6: {
if (tag !== 50) {
break;
}
message.country = reader.string();
continue;
}
}
if ((tag & 7) === 4 || tag === 0) {
break;
@ -2322,7 +2309,6 @@ export const UserSession_ClientInfo: MessageFns<UserSession_ClientInfo> = {
message.deviceType = object.deviceType ?? "";
message.os = object.os ?? "";
message.browser = object.browser ?? "";
message.country = object.country ?? "";
return message;
},
};