mirror of
				https://github.com/usememos/memos.git
				synced 2025-10-31 08:46:39 +08:00 
			
		
		
		
	
							parent
							
								
									e62a94c05a
								
							
						
					
					
						commit
						994d5dd891
					
				
					 7 changed files with 442 additions and 7 deletions
				
			
		|  | @ -191,7 +191,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha | |||
| 		if generateToken { | ||||
| 			generateTokenFunc := func() error { | ||||
| 				rc, err := c.Cookie(auth.RefreshTokenCookieName) | ||||
| 
 | ||||
| 				if err != nil { | ||||
| 					return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") | ||||
| 				} | ||||
|  |  | |||
|  | @ -53,11 +53,6 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { | |||
| 
 | ||||
| 	e.Use(middleware.Gzip()) | ||||
| 
 | ||||
| 	e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ | ||||
| 		Skipper:     s.defaultAuthSkipper, | ||||
| 		TokenLookup: "cookie:_csrf", | ||||
| 	})) | ||||
| 
 | ||||
| 	e.Use(middleware.CORS()) | ||||
| 
 | ||||
| 	e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ | ||||
|  | @ -141,6 +136,10 @@ func (s *Server) Shutdown(ctx context.Context) { | |||
| 	fmt.Printf("memos stopped properly\n") | ||||
| } | ||||
| 
 | ||||
| func (s *Server) GetEcho() *echo.Echo { | ||||
| 	return s.e | ||||
| } | ||||
| 
 | ||||
| func (s *Server) createServerStartActivity(ctx context.Context) error { | ||||
| 	payload := api.ActivityServerStartPayload{ | ||||
| 		ServerID: s.ID, | ||||
|  |  | |||
							
								
								
									
										54
									
								
								test/server/auth_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								test/server/auth_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,54 @@ | |||
| package testserver | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/usememos/memos/api" | ||||
| ) | ||||
| 
 | ||||
| func TestAuthServer(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	s, err := NewTestingServer(ctx, t) | ||||
| 	require.NoError(t, err) | ||||
| 	defer s.Shutdown(ctx) | ||||
| 
 | ||||
| 	signup := &api.SignUp{ | ||||
| 		Username: "testuser", | ||||
| 		Password: "testpassword", | ||||
| 	} | ||||
| 	user, err := s.postAuthSignup(signup) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, signup.Username, user.Username) | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) postAuthSignup(signup *api.SignUp) (*api.User, error) { | ||||
| 	rawData, err := json.Marshal(&signup) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to marshal signup") | ||||
| 	} | ||||
| 	reader := bytes.NewReader(rawData) | ||||
| 	body, err := s.post("/api/auth/signup", reader, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	_, err = buf.ReadFrom(body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to read response body") | ||||
| 	} | ||||
| 
 | ||||
| 	type AuthSignupResponse struct { | ||||
| 		Data *api.User `json:"data"` | ||||
| 	} | ||||
| 	res := new(AuthSignupResponse) | ||||
| 	if err = json.Unmarshal(buf.Bytes(), res); err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to unmarshal post signup response") | ||||
| 	} | ||||
| 	return res.Data, nil | ||||
| } | ||||
							
								
								
									
										134
									
								
								test/server/memo_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								test/server/memo_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,134 @@ | |||
| package testserver | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/usememos/memos/api" | ||||
| ) | ||||
| 
 | ||||
| func TestMemoServer(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	s, err := NewTestingServer(ctx, t) | ||||
| 	require.NoError(t, err) | ||||
| 	defer s.Shutdown(ctx) | ||||
| 
 | ||||
| 	signup := &api.SignUp{ | ||||
| 		Username: "testuser", | ||||
| 		Password: "testpassword", | ||||
| 	} | ||||
| 	user, err := s.postAuthSignup(signup) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, signup.Username, user.Username) | ||||
| 	memoList, err := s.getMemoList() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, memoList, 0) | ||||
| 	memo, err := s.postMemoCreate(&api.MemoCreate{ | ||||
| 		Content: "test memo", | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "test memo", memo.Content) | ||||
| 	memoList, err = s.getMemoList() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, memoList, 1) | ||||
| 	updatedContent := "updated memo" | ||||
| 	memo, err = s.patchMemoPatch(&api.MemoPatch{ | ||||
| 		ID:      memo.ID, | ||||
| 		Content: &updatedContent, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, updatedContent, memo.Content) | ||||
| 	err = s.postMemoDelete(&api.MemoDelete{ | ||||
| 		ID: memo.ID, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
| 	memoList, err = s.getMemoList() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, memoList, 0) | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) getMemoList() ([]*api.Memo, error) { | ||||
| 	body, err := s.get("/api/memo", nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	_, err = buf.ReadFrom(body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to read response body") | ||||
| 	} | ||||
| 
 | ||||
| 	type MemoCreateResponse struct { | ||||
| 		Data []*api.Memo `json:"data"` | ||||
| 	} | ||||
| 	res := new(MemoCreateResponse) | ||||
| 	if err = json.Unmarshal(buf.Bytes(), res); err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to unmarshal get memo list response") | ||||
| 	} | ||||
| 	return res.Data, nil | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) postMemoCreate(memoCreate *api.MemoCreate) (*api.Memo, error) { | ||||
| 	rawData, err := json.Marshal(&memoCreate) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to marshal memo create") | ||||
| 	} | ||||
| 	reader := bytes.NewReader(rawData) | ||||
| 	body, err := s.post("/api/memo", reader, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	_, err = buf.ReadFrom(body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to read response body") | ||||
| 	} | ||||
| 
 | ||||
| 	type MemoCreateResponse struct { | ||||
| 		Data *api.Memo `json:"data"` | ||||
| 	} | ||||
| 	res := new(MemoCreateResponse) | ||||
| 	if err = json.Unmarshal(buf.Bytes(), res); err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to unmarshal post memo create response") | ||||
| 	} | ||||
| 	return res.Data, nil | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) patchMemoPatch(memoPatch *api.MemoPatch) (*api.Memo, error) { | ||||
| 	rawData, err := json.Marshal(&memoPatch) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to marshal memo patch") | ||||
| 	} | ||||
| 	reader := bytes.NewReader(rawData) | ||||
| 	body, err := s.patch(fmt.Sprintf("/api/memo/%d", memoPatch.ID), reader, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	_, err = buf.ReadFrom(body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to read response body") | ||||
| 	} | ||||
| 
 | ||||
| 	type MemoPatchResponse struct { | ||||
| 		Data *api.Memo `json:"data"` | ||||
| 	} | ||||
| 	res := new(MemoPatchResponse) | ||||
| 	if err = json.Unmarshal(buf.Bytes(), res); err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to unmarshal patch memo response") | ||||
| 	} | ||||
| 	return res.Data, nil | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) postMemoDelete(memoDelete *api.MemoDelete) error { | ||||
| 	_, err := s.delete(fmt.Sprintf("/api/memo/%d", memoDelete.ID), nil) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										176
									
								
								test/server/server.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								test/server/server.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,176 @@ | |||
| package testserver | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/usememos/memos/server" | ||||
| 	"github.com/usememos/memos/server/profile" | ||||
| 	"github.com/usememos/memos/store/db" | ||||
| 	"github.com/usememos/memos/test" | ||||
| 
 | ||||
| 	// sqlite3 driver. | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
| 
 | ||||
| type TestingServer struct { | ||||
| 	server  *server.Server | ||||
| 	client  *http.Client | ||||
| 	profile *profile.Profile | ||||
| 	cookie  string | ||||
| } | ||||
| 
 | ||||
| func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) { | ||||
| 	profile := test.GetTestingProfile(t) | ||||
| 	db := db.NewDB(profile) | ||||
| 	if err := db.Open(ctx); err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to open db") | ||||
| 	} | ||||
| 
 | ||||
| 	server, err := server.NewServer(ctx, profile) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to create server") | ||||
| 	} | ||||
| 
 | ||||
| 	errChan := make(chan error, 1) | ||||
| 
 | ||||
| 	s := &TestingServer{ | ||||
| 		server:  server, | ||||
| 		client:  &http.Client{}, | ||||
| 		profile: profile, | ||||
| 		cookie:  "", | ||||
| 	} | ||||
| 
 | ||||
| 	go func() { | ||||
| 		if err := s.server.Start(ctx); err != nil { | ||||
| 			if err != http.ErrServerClosed { | ||||
| 				errChan <- errors.Wrap(err, "failed to run main server") | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	if err := s.waitForServerStart(errChan); err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to start server") | ||||
| 	} | ||||
| 
 | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) Shutdown(ctx context.Context) { | ||||
| 	s.server.Shutdown(ctx) | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) waitForServerStart(errChan <-chan error) error { | ||||
| 	ticker := time.NewTicker(100 * time.Millisecond) | ||||
| 	defer ticker.Stop() | ||||
| 
 | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ticker.C: | ||||
| 			if s == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			e := s.server.GetEcho() | ||||
| 			if e == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			addr := e.ListenerAddr() | ||||
| 			if addr != nil && strings.Contains(addr.String(), ":") { | ||||
| 				return nil // was started | ||||
| 			} | ||||
| 		case err := <-errChan: | ||||
| 			if err == http.ErrServerClosed { | ||||
| 				return nil | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) { | ||||
| 	fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri) | ||||
| 	req, err := http.NewRequest(method, fullURL, body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL) | ||||
| 	} | ||||
| 
 | ||||
| 	for k, v := range header { | ||||
| 		req.Header.Set(k, v) | ||||
| 	} | ||||
| 
 | ||||
| 	q := url.Values{} | ||||
| 	for k, v := range params { | ||||
| 		q.Add(k, v) | ||||
| 	} | ||||
| 	if len(q) > 0 { | ||||
| 		req.URL.RawQuery = q.Encode() | ||||
| 	} | ||||
| 
 | ||||
| 	resp, err := s.client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL) | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, err := io.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "failed to read http response body") | ||||
| 		} | ||||
| 		return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(body)) | ||||
| 	} | ||||
| 
 | ||||
| 	if method == "POST" { | ||||
| 		if strings.Contains(uri, "/api/auth/login") || strings.Contains(uri, "/api/auth/signup") { | ||||
| 			cookie := "" | ||||
| 			h := resp.Header.Get("Set-Cookie") | ||||
| 			parts := strings.Split(h, "; ") | ||||
| 			for _, p := range parts { | ||||
| 				if strings.HasPrefix(p, "access-token=") { | ||||
| 					cookie = p | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 			if cookie == "" { | ||||
| 				return nil, errors.Errorf("unable to find access token in the login response headers") | ||||
| 			} | ||||
| 			s.cookie = cookie | ||||
| 		} else if strings.Contains(uri, "/api/auth/logout") { | ||||
| 			s.cookie = "" | ||||
| 		} | ||||
| 	} | ||||
| 	return resp.Body, nil | ||||
| } | ||||
| 
 | ||||
| // get sends a GET client request. | ||||
| func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) { | ||||
| 	return s.request("GET", url, nil, params, map[string]string{ | ||||
| 		"Cookie": s.cookie, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // post sends a POST client request. | ||||
| func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) { | ||||
| 	return s.request("POST", url, body, params, map[string]string{ | ||||
| 		"Cookie": s.cookie, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // patch sends a PATCH client request. | ||||
| func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) { | ||||
| 	return s.request("PATCH", url, body, params, map[string]string{ | ||||
| 		"Cookie": s.cookie, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // delete sends a DELETE client request. | ||||
| func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) { | ||||
| 	return s.request("DELETE", url, nil, params, map[string]string{ | ||||
| 		"Cookie": s.cookie, | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										58
									
								
								test/server/system_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								test/server/system_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,58 @@ | |||
| package testserver | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/usememos/memos/api" | ||||
| ) | ||||
| 
 | ||||
| func TestSystemServer(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	s, err := NewTestingServer(ctx, t) | ||||
| 	require.NoError(t, err) | ||||
| 	defer s.Shutdown(ctx) | ||||
| 
 | ||||
| 	status, err := s.getSystemStatus() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, (*api.User)(nil), status.Host) | ||||
| 
 | ||||
| 	signup := &api.SignUp{ | ||||
| 		Username: "testuser", | ||||
| 		Password: "testpassword", | ||||
| 	} | ||||
| 	user, err := s.postAuthSignup(signup) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, signup.Username, user.Username) | ||||
| 
 | ||||
| 	status, err = s.getSystemStatus() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, user.ID, status.Host.ID) | ||||
| 	require.Equal(t, user.Username, status.Host.Username) | ||||
| } | ||||
| 
 | ||||
| func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) { | ||||
| 	body, err := s.get("/api/status", nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	buf := &bytes.Buffer{} | ||||
| 	_, err = buf.ReadFrom(body) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to read response body") | ||||
| 	} | ||||
| 
 | ||||
| 	type SystemStatusResponse struct { | ||||
| 		Data *api.SystemStatus `json:"data"` | ||||
| 	} | ||||
| 	res := new(SystemStatusResponse) | ||||
| 	if err = json.Unmarshal(buf.Bytes(), res); err != nil { | ||||
| 		return nil, errors.Wrap(err, "fail to unmarshal get system status response") | ||||
| 	} | ||||
| 	return res.Data, nil | ||||
| } | ||||
							
								
								
									
										17
									
								
								test/test.go
									
										
									
									
									
								
							
							
						
						
									
										17
									
								
								test/test.go
									
										
									
									
									
								
							|  | @ -2,19 +2,34 @@ package test | |||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/usememos/memos/server/profile" | ||||
| 	"github.com/usememos/memos/server/version" | ||||
| ) | ||||
| 
 | ||||
| func getUnusedPort() int { | ||||
| 	// Get a random unused port | ||||
| 	listener, err := net.Listen("tcp", "localhost:0") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	defer listener.Close() | ||||
| 
 | ||||
| 	// Get the port number | ||||
| 	port := listener.Addr().(*net.TCPAddr).Port | ||||
| 	return port | ||||
| } | ||||
| 
 | ||||
| func GetTestingProfile(t *testing.T) *profile.Profile { | ||||
| 	// Get a temporary directory for the test data. | ||||
| 	dir := t.TempDir() | ||||
| 	mode := "prod" | ||||
| 	port := getUnusedPort() | ||||
| 	return &profile.Profile{ | ||||
| 		Mode:    mode, | ||||
| 		Port:    8082, | ||||
| 		Port:    port, | ||||
| 		Data:    dir, | ||||
| 		DSN:     fmt.Sprintf("%s/memos_%s.db", dir, mode), | ||||
| 		Version: version.GetCurrentVersion(mode), | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue