diff --git a/server/jwt.go b/server/jwt.go index abe99930..f8f5317d 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -191,7 +191,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha if generateToken { generateTokenFunc := func() error { rc, err := c.Cookie(auth.RefreshTokenCookieName) - if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") } diff --git a/server/server.go b/server/server.go index 34cd56a3..75a24237 100644 --- a/server/server.go +++ b/server/server.go @@ -53,11 +53,6 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { e.Use(middleware.Gzip()) - e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ - Skipper: s.defaultAuthSkipper, - TokenLookup: "cookie:_csrf", - })) - e.Use(middleware.CORS()) e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ @@ -141,6 +136,10 @@ func (s *Server) Shutdown(ctx context.Context) { fmt.Printf("memos stopped properly\n") } +func (s *Server) GetEcho() *echo.Echo { + return s.e +} + func (s *Server) createServerStartActivity(ctx context.Context) error { payload := api.ActivityServerStartPayload{ ServerID: s.ID, diff --git a/test/server/auth_test.go b/test/server/auth_test.go new file mode 100644 index 00000000..51cf9f79 --- /dev/null +++ b/test/server/auth_test.go @@ -0,0 +1,54 @@ +package testserver + +import ( + "bytes" + "context" + "encoding/json" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" +) + +func TestAuthServer(t *testing.T) { + ctx := context.Background() + s, err := NewTestingServer(ctx, t) + require.NoError(t, err) + defer s.Shutdown(ctx) + + signup := &api.SignUp{ + Username: "testuser", + Password: "testpassword", + } + user, err := s.postAuthSignup(signup) + require.NoError(t, err) + require.Equal(t, signup.Username, user.Username) +} + +func (s *TestingServer) postAuthSignup(signup *api.SignUp) (*api.User, error) { + rawData, err := json.Marshal(&signup) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal signup") + } + reader := bytes.NewReader(rawData) + body, err := s.post("/api/auth/signup", reader, nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type AuthSignupResponse struct { + Data *api.User `json:"data"` + } + res := new(AuthSignupResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal post signup response") + } + return res.Data, nil +} diff --git a/test/server/memo_test.go b/test/server/memo_test.go new file mode 100644 index 00000000..7ee1174d --- /dev/null +++ b/test/server/memo_test.go @@ -0,0 +1,134 @@ +package testserver + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" +) + +func TestMemoServer(t *testing.T) { + ctx := context.Background() + s, err := NewTestingServer(ctx, t) + require.NoError(t, err) + defer s.Shutdown(ctx) + + signup := &api.SignUp{ + Username: "testuser", + Password: "testpassword", + } + user, err := s.postAuthSignup(signup) + require.NoError(t, err) + require.Equal(t, signup.Username, user.Username) + memoList, err := s.getMemoList() + require.NoError(t, err) + require.Len(t, memoList, 0) + memo, err := s.postMemoCreate(&api.MemoCreate{ + Content: "test memo", + }) + require.NoError(t, err) + require.Equal(t, "test memo", memo.Content) + memoList, err = s.getMemoList() + require.NoError(t, err) + require.Len(t, memoList, 1) + updatedContent := "updated memo" + memo, err = s.patchMemoPatch(&api.MemoPatch{ + ID: memo.ID, + Content: &updatedContent, + }) + require.NoError(t, err) + require.Equal(t, updatedContent, memo.Content) + err = s.postMemoDelete(&api.MemoDelete{ + ID: memo.ID, + }) + require.NoError(t, err) + memoList, err = s.getMemoList() + require.NoError(t, err) + require.Len(t, memoList, 0) +} + +func (s *TestingServer) getMemoList() ([]*api.Memo, error) { + body, err := s.get("/api/memo", nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type MemoCreateResponse struct { + Data []*api.Memo `json:"data"` + } + res := new(MemoCreateResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal get memo list response") + } + return res.Data, nil +} + +func (s *TestingServer) postMemoCreate(memoCreate *api.MemoCreate) (*api.Memo, error) { + rawData, err := json.Marshal(&memoCreate) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal memo create") + } + reader := bytes.NewReader(rawData) + body, err := s.post("/api/memo", reader, nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type MemoCreateResponse struct { + Data *api.Memo `json:"data"` + } + res := new(MemoCreateResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal post memo create response") + } + return res.Data, nil +} + +func (s *TestingServer) patchMemoPatch(memoPatch *api.MemoPatch) (*api.Memo, error) { + rawData, err := json.Marshal(&memoPatch) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal memo patch") + } + reader := bytes.NewReader(rawData) + body, err := s.patch(fmt.Sprintf("/api/memo/%d", memoPatch.ID), reader, nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type MemoPatchResponse struct { + Data *api.Memo `json:"data"` + } + res := new(MemoPatchResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal patch memo response") + } + return res.Data, nil +} + +func (s *TestingServer) postMemoDelete(memoDelete *api.MemoDelete) error { + _, err := s.delete(fmt.Sprintf("/api/memo/%d", memoDelete.ID), nil) + return err +} diff --git a/test/server/server.go b/test/server/server.go new file mode 100644 index 00000000..7a3a3b0e --- /dev/null +++ b/test/server/server.go @@ -0,0 +1,176 @@ +package testserver + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/usememos/memos/server" + "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/store/db" + "github.com/usememos/memos/test" + + // sqlite3 driver. + _ "github.com/mattn/go-sqlite3" +) + +type TestingServer struct { + server *server.Server + client *http.Client + profile *profile.Profile + cookie string +} + +func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) { + profile := test.GetTestingProfile(t) + db := db.NewDB(profile) + if err := db.Open(ctx); err != nil { + return nil, errors.Wrap(err, "failed to open db") + } + + server, err := server.NewServer(ctx, profile) + if err != nil { + return nil, errors.Wrap(err, "failed to create server") + } + + errChan := make(chan error, 1) + + s := &TestingServer{ + server: server, + client: &http.Client{}, + profile: profile, + cookie: "", + } + + go func() { + if err := s.server.Start(ctx); err != nil { + if err != http.ErrServerClosed { + errChan <- errors.Wrap(err, "failed to run main server") + } + } + }() + + if err := s.waitForServerStart(errChan); err != nil { + return nil, errors.Wrap(err, "failed to start server") + } + + return s, nil +} + +func (s *TestingServer) Shutdown(ctx context.Context) { + s.server.Shutdown(ctx) +} + +func (s *TestingServer) waitForServerStart(errChan <-chan error) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if s == nil { + continue + } + e := s.server.GetEcho() + if e == nil { + continue + } + addr := e.ListenerAddr() + if addr != nil && strings.Contains(addr.String(), ":") { + return nil // was started + } + case err := <-errChan: + if err == http.ErrServerClosed { + return nil + } + return err + } + } +} + +func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) { + fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri) + req, err := http.NewRequest(method, fullURL, body) + if err != nil { + return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL) + } + + for k, v := range header { + req.Header.Set(k, v) + } + + q := url.Values{} + for k, v := range params { + q.Add(k, v) + } + if len(q) > 0 { + req.URL.RawQuery = q.Encode() + } + + resp, err := s.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL) + } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read http response body") + } + return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(body)) + } + + if method == "POST" { + if strings.Contains(uri, "/api/auth/login") || strings.Contains(uri, "/api/auth/signup") { + cookie := "" + h := resp.Header.Get("Set-Cookie") + parts := strings.Split(h, "; ") + for _, p := range parts { + if strings.HasPrefix(p, "access-token=") { + cookie = p + break + } + } + if cookie == "" { + return nil, errors.Errorf("unable to find access token in the login response headers") + } + s.cookie = cookie + } else if strings.Contains(uri, "/api/auth/logout") { + s.cookie = "" + } + } + return resp.Body, nil +} + +// get sends a GET client request. +func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) { + return s.request("GET", url, nil, params, map[string]string{ + "Cookie": s.cookie, + }) +} + +// post sends a POST client request. +func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) { + return s.request("POST", url, body, params, map[string]string{ + "Cookie": s.cookie, + }) +} + +// patch sends a PATCH client request. +func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) { + return s.request("PATCH", url, body, params, map[string]string{ + "Cookie": s.cookie, + }) +} + +// delete sends a DELETE client request. +func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) { + return s.request("DELETE", url, nil, params, map[string]string{ + "Cookie": s.cookie, + }) +} diff --git a/test/server/system_test.go b/test/server/system_test.go new file mode 100644 index 00000000..ef0a11dc --- /dev/null +++ b/test/server/system_test.go @@ -0,0 +1,58 @@ +package testserver + +import ( + "bytes" + "context" + "encoding/json" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" +) + +func TestSystemServer(t *testing.T) { + ctx := context.Background() + s, err := NewTestingServer(ctx, t) + require.NoError(t, err) + defer s.Shutdown(ctx) + + status, err := s.getSystemStatus() + require.NoError(t, err) + require.Equal(t, (*api.User)(nil), status.Host) + + signup := &api.SignUp{ + Username: "testuser", + Password: "testpassword", + } + user, err := s.postAuthSignup(signup) + require.NoError(t, err) + require.Equal(t, signup.Username, user.Username) + + status, err = s.getSystemStatus() + require.NoError(t, err) + require.Equal(t, user.ID, status.Host.ID) + require.Equal(t, user.Username, status.Host.Username) +} + +func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) { + body, err := s.get("/api/status", nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type SystemStatusResponse struct { + Data *api.SystemStatus `json:"data"` + } + res := new(SystemStatusResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal get system status response") + } + return res.Data, nil +} diff --git a/test/test.go b/test/test.go index 384be8d2..a70c2b85 100644 --- a/test/test.go +++ b/test/test.go @@ -2,19 +2,34 @@ package test import ( "fmt" + "net" "testing" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/server/version" ) +func getUnusedPort() int { + // Get a random unused port + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + defer listener.Close() + + // Get the port number + port := listener.Addr().(*net.TCPAddr).Port + return port +} + func GetTestingProfile(t *testing.T) *profile.Profile { // Get a temporary directory for the test data. dir := t.TempDir() mode := "prod" + port := getUnusedPort() return &profile.Profile{ Mode: mode, - Port: 8082, + Port: port, Data: dir, DSN: fmt.Sprintf("%s/memos_%s.db", dir, mode), Version: version.GetCurrentVersion(mode),