feat: server tests (#1556)

* feat: server tests

* chore: update
This commit is contained in:
boojack 2023-04-17 21:34:59 +08:00 committed by GitHub
parent e62a94c05a
commit 994d5dd891
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 442 additions and 7 deletions

View file

@ -191,7 +191,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
if generateToken { if generateToken {
generateTokenFunc := func() error { generateTokenFunc := func() error {
rc, err := c.Cookie(auth.RefreshTokenCookieName) rc, err := c.Cookie(auth.RefreshTokenCookieName)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.")
} }

View file

@ -53,11 +53,6 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
e.Use(middleware.Gzip()) e.Use(middleware.Gzip())
e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: s.defaultAuthSkipper,
TokenLookup: "cookie:_csrf",
}))
e.Use(middleware.CORS()) e.Use(middleware.CORS())
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
@ -141,6 +136,10 @@ func (s *Server) Shutdown(ctx context.Context) {
fmt.Printf("memos stopped properly\n") fmt.Printf("memos stopped properly\n")
} }
func (s *Server) GetEcho() *echo.Echo {
return s.e
}
func (s *Server) createServerStartActivity(ctx context.Context) error { func (s *Server) createServerStartActivity(ctx context.Context) error {
payload := api.ActivityServerStartPayload{ payload := api.ActivityServerStartPayload{
ServerID: s.ID, ServerID: s.ID,

54
test/server/auth_test.go Normal file
View file

@ -0,0 +1,54 @@
package testserver
import (
"bytes"
"context"
"encoding/json"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
)
func TestAuthServer(t *testing.T) {
ctx := context.Background()
s, err := NewTestingServer(ctx, t)
require.NoError(t, err)
defer s.Shutdown(ctx)
signup := &api.SignUp{
Username: "testuser",
Password: "testpassword",
}
user, err := s.postAuthSignup(signup)
require.NoError(t, err)
require.Equal(t, signup.Username, user.Username)
}
func (s *TestingServer) postAuthSignup(signup *api.SignUp) (*api.User, error) {
rawData, err := json.Marshal(&signup)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal signup")
}
reader := bytes.NewReader(rawData)
body, err := s.post("/api/auth/signup", reader, nil)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(body)
if err != nil {
return nil, errors.Wrap(err, "fail to read response body")
}
type AuthSignupResponse struct {
Data *api.User `json:"data"`
}
res := new(AuthSignupResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal post signup response")
}
return res.Data, nil
}

134
test/server/memo_test.go Normal file
View file

@ -0,0 +1,134 @@
package testserver
import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
)
func TestMemoServer(t *testing.T) {
ctx := context.Background()
s, err := NewTestingServer(ctx, t)
require.NoError(t, err)
defer s.Shutdown(ctx)
signup := &api.SignUp{
Username: "testuser",
Password: "testpassword",
}
user, err := s.postAuthSignup(signup)
require.NoError(t, err)
require.Equal(t, signup.Username, user.Username)
memoList, err := s.getMemoList()
require.NoError(t, err)
require.Len(t, memoList, 0)
memo, err := s.postMemoCreate(&api.MemoCreate{
Content: "test memo",
})
require.NoError(t, err)
require.Equal(t, "test memo", memo.Content)
memoList, err = s.getMemoList()
require.NoError(t, err)
require.Len(t, memoList, 1)
updatedContent := "updated memo"
memo, err = s.patchMemoPatch(&api.MemoPatch{
ID: memo.ID,
Content: &updatedContent,
})
require.NoError(t, err)
require.Equal(t, updatedContent, memo.Content)
err = s.postMemoDelete(&api.MemoDelete{
ID: memo.ID,
})
require.NoError(t, err)
memoList, err = s.getMemoList()
require.NoError(t, err)
require.Len(t, memoList, 0)
}
func (s *TestingServer) getMemoList() ([]*api.Memo, error) {
body, err := s.get("/api/memo", nil)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(body)
if err != nil {
return nil, errors.Wrap(err, "fail to read response body")
}
type MemoCreateResponse struct {
Data []*api.Memo `json:"data"`
}
res := new(MemoCreateResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal get memo list response")
}
return res.Data, nil
}
func (s *TestingServer) postMemoCreate(memoCreate *api.MemoCreate) (*api.Memo, error) {
rawData, err := json.Marshal(&memoCreate)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal memo create")
}
reader := bytes.NewReader(rawData)
body, err := s.post("/api/memo", reader, nil)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(body)
if err != nil {
return nil, errors.Wrap(err, "fail to read response body")
}
type MemoCreateResponse struct {
Data *api.Memo `json:"data"`
}
res := new(MemoCreateResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal post memo create response")
}
return res.Data, nil
}
func (s *TestingServer) patchMemoPatch(memoPatch *api.MemoPatch) (*api.Memo, error) {
rawData, err := json.Marshal(&memoPatch)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal memo patch")
}
reader := bytes.NewReader(rawData)
body, err := s.patch(fmt.Sprintf("/api/memo/%d", memoPatch.ID), reader, nil)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(body)
if err != nil {
return nil, errors.Wrap(err, "fail to read response body")
}
type MemoPatchResponse struct {
Data *api.Memo `json:"data"`
}
res := new(MemoPatchResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal patch memo response")
}
return res.Data, nil
}
func (s *TestingServer) postMemoDelete(memoDelete *api.MemoDelete) error {
_, err := s.delete(fmt.Sprintf("/api/memo/%d", memoDelete.ID), nil)
return err
}

176
test/server/server.go Normal file
View file

@ -0,0 +1,176 @@
package testserver
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/server"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store/db"
"github.com/usememos/memos/test"
// sqlite3 driver.
_ "github.com/mattn/go-sqlite3"
)
type TestingServer struct {
server *server.Server
client *http.Client
profile *profile.Profile
cookie string
}
func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
profile := test.GetTestingProfile(t)
db := db.NewDB(profile)
if err := db.Open(ctx); err != nil {
return nil, errors.Wrap(err, "failed to open db")
}
server, err := server.NewServer(ctx, profile)
if err != nil {
return nil, errors.Wrap(err, "failed to create server")
}
errChan := make(chan error, 1)
s := &TestingServer{
server: server,
client: &http.Client{},
profile: profile,
cookie: "",
}
go func() {
if err := s.server.Start(ctx); err != nil {
if err != http.ErrServerClosed {
errChan <- errors.Wrap(err, "failed to run main server")
}
}
}()
if err := s.waitForServerStart(errChan); err != nil {
return nil, errors.Wrap(err, "failed to start server")
}
return s, nil
}
func (s *TestingServer) Shutdown(ctx context.Context) {
s.server.Shutdown(ctx)
}
func (s *TestingServer) waitForServerStart(errChan <-chan error) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if s == nil {
continue
}
e := s.server.GetEcho()
if e == nil {
continue
}
addr := e.ListenerAddr()
if addr != nil && strings.Contains(addr.String(), ":") {
return nil // was started
}
case err := <-errChan:
if err == http.ErrServerClosed {
return nil
}
return err
}
}
}
func (s *TestingServer) request(method, uri string, body io.Reader, params, header map[string]string) (io.ReadCloser, error) {
fullURL := fmt.Sprintf("http://localhost:%d%s", s.profile.Port, uri)
req, err := http.NewRequest(method, fullURL, body)
if err != nil {
return nil, errors.Wrapf(err, "fail to create a new %s request(%q)", method, fullURL)
}
for k, v := range header {
req.Header.Set(k, v)
}
q := url.Values{}
for k, v := range params {
q.Add(k, v)
}
if len(q) > 0 {
req.URL.RawQuery = q.Encode()
}
resp, err := s.client.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "fail to send a %s request(%q)", method, fullURL)
}
if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read http response body")
}
return nil, errors.Errorf("http response error code %v body %q", resp.StatusCode, string(body))
}
if method == "POST" {
if strings.Contains(uri, "/api/auth/login") || strings.Contains(uri, "/api/auth/signup") {
cookie := ""
h := resp.Header.Get("Set-Cookie")
parts := strings.Split(h, "; ")
for _, p := range parts {
if strings.HasPrefix(p, "access-token=") {
cookie = p
break
}
}
if cookie == "" {
return nil, errors.Errorf("unable to find access token in the login response headers")
}
s.cookie = cookie
} else if strings.Contains(uri, "/api/auth/logout") {
s.cookie = ""
}
}
return resp.Body, nil
}
// get sends a GET client request.
func (s *TestingServer) get(url string, params map[string]string) (io.ReadCloser, error) {
return s.request("GET", url, nil, params, map[string]string{
"Cookie": s.cookie,
})
}
// post sends a POST client request.
func (s *TestingServer) post(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
return s.request("POST", url, body, params, map[string]string{
"Cookie": s.cookie,
})
}
// patch sends a PATCH client request.
func (s *TestingServer) patch(url string, body io.Reader, params map[string]string) (io.ReadCloser, error) {
return s.request("PATCH", url, body, params, map[string]string{
"Cookie": s.cookie,
})
}
// delete sends a DELETE client request.
func (s *TestingServer) delete(url string, params map[string]string) (io.ReadCloser, error) {
return s.request("DELETE", url, nil, params, map[string]string{
"Cookie": s.cookie,
})
}

View file

@ -0,0 +1,58 @@
package testserver
import (
"bytes"
"context"
"encoding/json"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
)
func TestSystemServer(t *testing.T) {
ctx := context.Background()
s, err := NewTestingServer(ctx, t)
require.NoError(t, err)
defer s.Shutdown(ctx)
status, err := s.getSystemStatus()
require.NoError(t, err)
require.Equal(t, (*api.User)(nil), status.Host)
signup := &api.SignUp{
Username: "testuser",
Password: "testpassword",
}
user, err := s.postAuthSignup(signup)
require.NoError(t, err)
require.Equal(t, signup.Username, user.Username)
status, err = s.getSystemStatus()
require.NoError(t, err)
require.Equal(t, user.ID, status.Host.ID)
require.Equal(t, user.Username, status.Host.Username)
}
func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) {
body, err := s.get("/api/status", nil)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(body)
if err != nil {
return nil, errors.Wrap(err, "fail to read response body")
}
type SystemStatusResponse struct {
Data *api.SystemStatus `json:"data"`
}
res := new(SystemStatusResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal get system status response")
}
return res.Data, nil
}

View file

@ -2,19 +2,34 @@ package test
import ( import (
"fmt" "fmt"
"net"
"testing" "testing"
"github.com/usememos/memos/server/profile" "github.com/usememos/memos/server/profile"
"github.com/usememos/memos/server/version" "github.com/usememos/memos/server/version"
) )
func getUnusedPort() int {
// Get a random unused port
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic(err)
}
defer listener.Close()
// Get the port number
port := listener.Addr().(*net.TCPAddr).Port
return port
}
func GetTestingProfile(t *testing.T) *profile.Profile { func GetTestingProfile(t *testing.T) *profile.Profile {
// Get a temporary directory for the test data. // Get a temporary directory for the test data.
dir := t.TempDir() dir := t.TempDir()
mode := "prod" mode := "prod"
port := getUnusedPort()
return &profile.Profile{ return &profile.Profile{
Mode: mode, Mode: mode,
Port: 8082, Port: port,
Data: dir, Data: dir,
DSN: fmt.Sprintf("%s/memos_%s.db", dir, mode), DSN: fmt.Sprintf("%s/memos_%s.db", dir, mode),
Version: version.GetCurrentVersion(mode), Version: version.GetCurrentVersion(mode),