mirror of
				https://github.com/usememos/memos.git
				synced 2025-10-31 16:59:30 +08:00 
			
		
		
		
	feat: persistent session name (#902)
* feat: persistent session name * chore: update
This commit is contained in:
		
							parent
							
								
									92a8a4ac0c
								
							
						
					
					
						commit
						d50ad9433f
					
				
					 7 changed files with 96 additions and 59 deletions
				
			
		|  | @ -13,6 +13,8 @@ type SystemSettingName string | |||
| const ( | ||||
| 	// SystemSettingServerID is the key type of server id. | ||||
| 	SystemSettingServerID SystemSettingName = "serverId" | ||||
| 	// SystemSettingSecretSessionName is the key type of secret session name. | ||||
| 	SystemSettingSecretSessionName SystemSettingName = "secretSessionName" | ||||
| 	// SystemSettingAllowSignUpName is the key type of allow signup setting. | ||||
| 	SystemSettingAllowSignUpName SystemSettingName = "allowSignUp" | ||||
| 	// SystemSettingAdditionalStyleName is the key type of additional style. | ||||
|  | @ -43,6 +45,8 @@ func (key SystemSettingName) String() string { | |||
| 	switch key { | ||||
| 	case SystemSettingServerID: | ||||
| 		return "serverId" | ||||
| 	case SystemSettingSecretSessionName: | ||||
| 		return "secretSessionName" | ||||
| 	case SystemSettingAllowSignUpName: | ||||
| 		return "allowSignUp" | ||||
| 	case SystemSettingAdditionalStyleName: | ||||
|  |  | |||
|  | @ -4,15 +4,13 @@ import ( | |||
| 	"os" | ||||
| 
 | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| 	"github.com/pkg/errors" | ||||
| 
 | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/usememos/memos/server" | ||||
| 	"github.com/usememos/memos/server/profile" | ||||
| 	"github.com/usememos/memos/store" | ||||
| 
 | ||||
| 	DB "github.com/usememos/memos/store/db" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -40,20 +38,11 @@ func run() error { | |||
| 	println("version:", profile.Version) | ||||
| 	println("---") | ||||
| 
 | ||||
| 	db := DB.NewDB(profile) | ||||
| 	if err := db.Open(ctx); err != nil { | ||||
| 		return fmt.Errorf("cannot open db: %w", err) | ||||
| 	serverInstance, err := server.NewServer(ctx, profile) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "failed to start server") | ||||
| 	} | ||||
| 
 | ||||
| 	serverInstance := server.NewServer(profile) | ||||
| 	storeInstance := store.New(db.Db, profile) | ||||
| 	serverInstance.Store = storeInstance | ||||
| 
 | ||||
| 	metricCollector := server.NewMetricCollector(profile, storeInstance) | ||||
| 	// Disable metrics collector. | ||||
| 	metricCollector.Enabled = false | ||||
| 	serverInstance.Collector = &metricCollector | ||||
| 
 | ||||
| 	println(greetingBanner) | ||||
| 	fmt.Printf("Version %s has started at :%d\n", profile.Version, profile.Port) | ||||
| 	return serverInstance.Run(ctx) | ||||
|  |  | |||
							
								
								
									
										2
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -16,7 +16,7 @@ require github.com/labstack/echo/v4 v4.9.0 | |||
| require ( | ||||
| 	github.com/VictoriaMetrics/fastcache v1.10.0 | ||||
| 	github.com/gorilla/feeds v1.1.1 | ||||
| 	github.com/gorilla/securecookie v1.1.1 | ||||
| 	github.com/gorilla/securecookie v1.1.1 // indirect | ||||
| 	github.com/gorilla/sessions v1.2.1 | ||||
| 	github.com/labstack/echo-contrib v0.13.0 | ||||
| 	github.com/stretchr/testify v1.8.1 | ||||
|  |  | |||
|  | @ -6,14 +6,12 @@ import ( | |||
| 	"fmt" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/usememos/memos/api" | ||||
| 	"github.com/usememos/memos/common" | ||||
| 	"github.com/usememos/memos/server/profile" | ||||
| 	"github.com/usememos/memos/store" | ||||
| 	"github.com/usememos/memos/store/db" | ||||
| 
 | ||||
| 	"github.com/gorilla/securecookie" | ||||
| 	"github.com/gorilla/sessions" | ||||
| 	"github.com/labstack/echo-contrib/session" | ||||
| 	"github.com/labstack/echo/v4" | ||||
|  | @ -23,16 +21,13 @@ import ( | |||
| type Server struct { | ||||
| 	e *echo.Echo | ||||
| 
 | ||||
| 	ID string | ||||
| 
 | ||||
| 	ID        string | ||||
| 	Profile   *profile.Profile | ||||
| 	Store     *store.Store | ||||
| 	Collector *MetricCollector | ||||
| 
 | ||||
| 	Profile *profile.Profile | ||||
| 
 | ||||
| 	Store *store.Store | ||||
| } | ||||
| 
 | ||||
| func NewServer(profile *profile.Profile) *Server { | ||||
| func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { | ||||
| 	e := echo.New() | ||||
| 	e.Debug = true | ||||
| 	e.HideBanner = true | ||||
|  | @ -43,6 +38,19 @@ func NewServer(profile *profile.Profile) *Server { | |||
| 		Profile: profile, | ||||
| 	} | ||||
| 
 | ||||
| 	db := db.NewDB(profile) | ||||
| 	if err := db.Open(ctx); err != nil { | ||||
| 		return nil, errors.Wrap(err, "cannot open db") | ||||
| 	} | ||||
| 
 | ||||
| 	storeInstance := store.New(db.DBInstance, profile) | ||||
| 	s.Store = storeInstance | ||||
| 
 | ||||
| 	metricCollector := NewMetricCollector(profile, storeInstance) | ||||
| 	// Disable metrics collector. | ||||
| 	metricCollector.Enabled = false | ||||
| 	s.Collector = &metricCollector | ||||
| 
 | ||||
| 	e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ | ||||
| 		Format: `{"time":"${time_rfc3339}",` + | ||||
| 			`"method":"${method}","uri":"${uri}",` + | ||||
|  | @ -68,14 +76,22 @@ func NewServer(profile *profile.Profile) *Server { | |||
| 		Timeout:      30 * time.Second, | ||||
| 	})) | ||||
| 
 | ||||
| 	embedFrontend(e) | ||||
| 
 | ||||
| 	// In dev mode, set the const secret key to make signin session persistence. | ||||
| 	secret := []byte("usememos") | ||||
| 	if profile.Mode == "prod" { | ||||
| 		secret = securecookie.GenerateRandomKey(16) | ||||
| 	serverID, err := s.getSystemServerID(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	e.Use(session.Middleware(sessions.NewCookieStore(secret))) | ||||
| 	s.ID = serverID | ||||
| 
 | ||||
| 	secretSessionName := "usememos" | ||||
| 	if profile.Mode == "prod" { | ||||
| 		secretSessionName, err = s.getSystemSecretSessionName(ctx) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 	e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName)))) | ||||
| 
 | ||||
| 	embedFrontend(e) | ||||
| 
 | ||||
| 	rootGroup := e.Group("") | ||||
| 	s.registerRSSRoutes(rootGroup) | ||||
|  | @ -99,28 +115,10 @@ func NewServer(profile *profile.Profile) *Server { | |||
| 	s.registerResourceRoutes(apiGroup) | ||||
| 	s.registerTagRoutes(apiGroup) | ||||
| 
 | ||||
| 	return s | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) Run(ctx context.Context) error { | ||||
| 	serverIDKey := api.SystemSettingServerID | ||||
| 	serverIDValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ | ||||
| 		Name: &serverIDKey, | ||||
| 	}) | ||||
| 	if err != nil && common.ErrorCode(err) != common.NotFound { | ||||
| 		return err | ||||
| 	} | ||||
| 	if serverIDValue == nil || serverIDValue.Value == "" { | ||||
| 		serverIDValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ | ||||
| 			Name:  serverIDKey, | ||||
| 			Value: uuid.NewString(), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	s.ID = serverIDValue.Value | ||||
| 
 | ||||
| 	if err := s.createServerStartActivity(ctx); err != nil { | ||||
| 		return errors.Wrap(err, "failed to create activity") | ||||
| 	} | ||||
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/usememos/memos/api" | ||||
| 	"github.com/usememos/memos/common" | ||||
| 
 | ||||
|  | @ -61,6 +63,10 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { | |||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err) | ||||
| 		} | ||||
| 		for _, systemSetting := range systemSettingList { | ||||
| 			if systemSetting.Name == api.SystemSettingServerID || systemSetting.Name == api.SystemSettingSecretSessionName { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 			var value interface{} | ||||
| 			err := json.Unmarshal([]byte(systemSetting.Value), &value) | ||||
| 			if err != nil { | ||||
|  | @ -195,3 +201,43 @@ func (s *Server) registerSystemRoutes(g *echo.Group) { | |||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) getSystemServerID(ctx context.Context) (string, error) { | ||||
| 	serverIDKey := api.SystemSettingServerID | ||||
| 	serverIDValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ | ||||
| 		Name: &serverIDKey, | ||||
| 	}) | ||||
| 	if err != nil && common.ErrorCode(err) != common.NotFound { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if serverIDValue == nil || serverIDValue.Value == "" { | ||||
| 		serverIDValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ | ||||
| 			Name:  serverIDKey, | ||||
| 			Value: uuid.NewString(), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	return serverIDValue.Value, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) { | ||||
| 	secretSessionNameKey := api.SystemSettingSecretSessionName | ||||
| 	secretSessionNameValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ | ||||
| 		Name: &secretSessionNameKey, | ||||
| 	}) | ||||
| 	if err != nil && common.ErrorCode(err) != common.NotFound { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if secretSessionNameValue == nil || secretSessionNameValue.Value == "" { | ||||
| 		secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ | ||||
| 			Name:  secretSessionNameKey, | ||||
| 			Value: uuid.NewString(), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	return secretSessionNameValue.Value, nil | ||||
| } | ||||
|  |  | |||
|  | @ -24,8 +24,8 @@ var seedFS embed.FS | |||
| 
 | ||||
| type DB struct { | ||||
| 	// sqlite db connection instance | ||||
| 	Db      *sql.DB | ||||
| 	profile *profile.Profile | ||||
| 	DBInstance *sql.DB | ||||
| 	profile    *profile.Profile | ||||
| } | ||||
| 
 | ||||
| // NewDB returns a new instance of DB associated with the given datasource name. | ||||
|  | @ -47,7 +47,7 @@ func (db *DB) Open(ctx context.Context) (err error) { | |||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to open db with dsn: %s, err: %w", db.profile.DSN, err) | ||||
| 	} | ||||
| 	db.Db = sqliteDB | ||||
| 	db.DBInstance = sqliteDB | ||||
| 
 | ||||
| 	if db.profile.Mode == "dev" { | ||||
| 		// In dev mode, we should migrate and seed the database. | ||||
|  | @ -156,7 +156,7 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	tx, err := db.Db.Begin() | ||||
| 	tx, err := db.DBInstance.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -197,7 +197,7 @@ func (db *DB) seed(ctx context.Context) error { | |||
| 
 | ||||
| // execute runs a single SQL statement within a transaction. | ||||
| func (db *DB) execute(ctx context.Context, stmt string) error { | ||||
| 	tx, err := db.Db.Begin() | ||||
| 	tx, err := db.DBInstance.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  |  | |||
|  | @ -20,7 +20,7 @@ type MigrationHistoryFind struct { | |||
| } | ||||
| 
 | ||||
| func (db *DB) FindMigrationHistory(ctx context.Context, find *MigrationHistoryFind) (*MigrationHistory, error) { | ||||
| 	tx, err := db.Db.BeginTx(ctx, nil) | ||||
| 	tx, err := db.DBInstance.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -40,7 +40,7 @@ func (db *DB) FindMigrationHistory(ctx context.Context, find *MigrationHistoryFi | |||
| } | ||||
| 
 | ||||
| func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { | ||||
| 	tx, err := db.Db.BeginTx(ctx, nil) | ||||
| 	tx, err := db.DBInstance.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue