fix: authentication flow should abort early (#888)

* fix: finish connection once we send auth response

* removed interface for now

* handle authentication in each route group

* tags api tests

* typo

* testutil improvements

* bookmarks api auth

* cache update requires owner
This commit is contained in:
Felipe Martin 2024-04-13 19:45:03 +02:00 committed by GitHub
parent 86337a088b
commit db313f5c62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 161 additions and 39 deletions

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-shiori/shiori/internal/http/response"
"github.com/go-shiori/shiori/internal/model" "github.com/go-shiori/shiori/internal/model"
"github.com/go-shiori/shiori/internal/testutil" "github.com/go-shiori/shiori/internal/testutil"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -18,8 +19,13 @@ func TestAuthenticationRequiredMiddleware(t *testing.T) {
t.Run("test unauthorized", func(t *testing.T) { t.Run("test unauthorized", func(t *testing.T) {
g := testutil.NewGin() g := testutil.NewGin()
g.Use(AuthenticationRequired()) g.Use(AuthenticationRequired())
g.Handle("GET", "/", func(c *gin.Context) {
response.Send(c, http.StatusOK, nil)
})
w := testutil.PerformRequest(g, "GET", "/") w := testutil.PerformRequest(g, "GET", "/")
require.Equal(t, http.StatusUnauthorized, w.Code) require.Equal(t, http.StatusUnauthorized, w.Code)
// This ensures we are aborting the request and not sending more data
require.Equal(t, `{"ok":false,"message":null}`, w.Body.String())
}) })
t.Run("test authorized", func(t *testing.T) { t.Run("test authorized", func(t *testing.T) {

View file

@ -10,7 +10,7 @@ type Response struct {
Ok bool `json:"ok"` Ok bool `json:"ok"`
// Message the payload of the response, depending on the endpoint/response status // Message the payload of the response, depending on the endpoint/response status
Message interface{} `json:"message"` Message any `json:"message"`
// ErrorParams parameters defined if the response is not successful to help client's debugging // ErrorParams parameters defined if the response is not successful to help client's debugging
ErrorParams map[string]string `json:"error_params,omitempty"` ErrorParams map[string]string `json:"error_params,omitempty"`
@ -20,7 +20,11 @@ type Response struct {
} }
func (m *Response) IsError() bool { func (m *Response) IsError() bool {
return m.Ok return !m.Ok
}
func (m *Response) GetMessage() any {
return m.Message
} }
func (m *Response) Send(c *gin.Context) { func (m *Response) Send(c *gin.Context) {
@ -28,7 +32,7 @@ func (m *Response) Send(c *gin.Context) {
c.JSON(m.statusCode, m) c.JSON(m.statusCode, m)
} }
func NewResponse(ok bool, message interface{}, errorParams map[string]string, statusCode int) *Response { func NewResponse(ok bool, message any, errorParams map[string]string, statusCode int) *Response {
return &Response{ return &Response{
Ok: ok, Ok: ok,
Message: message, Message: message,

View file

@ -21,6 +21,7 @@ func Send(ctx *gin.Context, statusCode int, data interface{}) {
// SendError provides a shortcut to send an unsuccessful response // SendError provides a shortcut to send an unsuccessful response
func SendError(ctx *gin.Context, statusCode int, data interface{}) { func SendError(ctx *gin.Context, statusCode int, data interface{}) {
New(false, statusCode, data).Send(ctx) New(false, statusCode, data).Send(ctx)
ctx.Abort()
} }
// SendErrorWithParams the same as above but for errors that require error parameters // SendErrorWithParams the same as above but for errors that require error parameters

View file

@ -3,7 +3,6 @@ package api_v1
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-shiori/shiori/internal/dependencies" "github.com/go-shiori/shiori/internal/dependencies"
"github.com/go-shiori/shiori/internal/http/middleware"
"github.com/go-shiori/shiori/internal/model" "github.com/go-shiori/shiori/internal/model"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -17,10 +16,7 @@ type APIRoutes struct {
func (r *APIRoutes) Setup(g *gin.RouterGroup) model.Routes { func (r *APIRoutes) Setup(g *gin.RouterGroup) model.Routes {
// Account API handles authentication in each route // Account API handles authentication in each route
r.handle(g, "/auth", NewAuthAPIRoutes(r.logger, r.deps, r.loginHandler)) r.handle(g, "/auth", NewAuthAPIRoutes(r.logger, r.deps, r.loginHandler))
r.handle(g, "/bookmarks", NewBookmarksAPIRoutes(r.logger, r.deps))
// From here on, all routes require authentication
g.Use(middleware.AuthenticationRequired())
r.handle(g, "/bookmarks", NewBookmarksPIRoutes(r.logger, r.deps))
r.handle(g, "/tags", NewTagsPIRoutes(r.logger, r.deps)) r.handle(g, "/tags", NewTagsPIRoutes(r.logger, r.deps))
return r return r

View file

@ -13,6 +13,7 @@ import (
"github.com/go-shiori/shiori/internal/database" "github.com/go-shiori/shiori/internal/database"
"github.com/go-shiori/shiori/internal/dependencies" "github.com/go-shiori/shiori/internal/dependencies"
"github.com/go-shiori/shiori/internal/http/context" "github.com/go-shiori/shiori/internal/http/context"
"github.com/go-shiori/shiori/internal/http/middleware"
"github.com/go-shiori/shiori/internal/http/response" "github.com/go-shiori/shiori/internal/http/response"
"github.com/go-shiori/shiori/internal/model" "github.com/go-shiori/shiori/internal/model"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -24,11 +25,12 @@ type BookmarksAPIRoutes struct {
} }
func (r *BookmarksAPIRoutes) Setup(g *gin.RouterGroup) model.Routes { func (r *BookmarksAPIRoutes) Setup(g *gin.RouterGroup) model.Routes {
g.Use(middleware.AuthenticationRequired())
g.PUT("/cache", r.updateCache) g.PUT("/cache", r.updateCache)
return r return r
} }
func NewBookmarksPIRoutes(logger *logrus.Logger, deps *dependencies.Dependencies) *BookmarksAPIRoutes { func NewBookmarksAPIRoutes(logger *logrus.Logger, deps *dependencies.Dependencies) *BookmarksAPIRoutes {
return &BookmarksAPIRoutes{ return &BookmarksAPIRoutes{
logger: logger, logger: logger,
deps: deps, deps: deps,
@ -67,7 +69,7 @@ func (p *updateCachePayload) IsValid() error {
// @Router /api/v1/bookmarks/cache [put] // @Router /api/v1/bookmarks/cache [put]
func (r *BookmarksAPIRoutes) updateCache(c *gin.Context) { func (r *BookmarksAPIRoutes) updateCache(c *gin.Context) {
ctx := context.NewContextFromGin(c) ctx := context.NewContextFromGin(c)
if !ctx.UserIsLogged() { if !ctx.GetAccount().Owner {
response.SendError(c, http.StatusForbidden, nil) response.SendError(c, http.StatusForbidden, nil)
return return
} }
@ -185,7 +187,7 @@ func (r *BookmarksAPIRoutes) updateCache(c *gin.Context) {
close(chDone) close(chDone)
// Update database // Update database
_, err = r.deps.Database.SaveBookmarks(ctx, false, bookmarks...) _, err = r.deps.Database.SaveBookmarks(c, false, bookmarks...)
if err != nil { if err != nil {
r.logger.WithError(err).Error("error update bookmakrs on deatabas") r.logger.WithError(err).Error("error update bookmakrs on deatabas")
response.SendInternalServerError(c) response.SendInternalServerError(c)

View file

@ -1 +1,47 @@
package api_v1 package api_v1
import (
"context"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/go-shiori/shiori/internal/http/middleware"
"github.com/go-shiori/shiori/internal/model"
"github.com/go-shiori/shiori/internal/testutil"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func TestUpdateBookmarkCache(t *testing.T) {
logger := logrus.New()
ctx := context.TODO()
g := gin.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
g.Use(middleware.AuthMiddleware(deps))
router := NewBookmarksAPIRoutes(logger, deps)
router.Setup(g.Group("/"))
account := model.Account{
Username: "test",
Password: "test",
Owner: false,
}
require.NoError(t, deps.Database.SaveAccount(ctx, account))
token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
require.NoError(t, err)
t.Run("require authentication", func(t *testing.T) {
w := testutil.PerformRequest(g, "PUT", "/cache")
require.Equal(t, http.StatusUnauthorized, w.Code)
})
t.Run("require owner", func(t *testing.T) {
w := testutil.PerformRequest(g, "PUT", "/cache", testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token))
require.Equal(t, http.StatusForbidden, w.Code)
})
}

View file

@ -5,6 +5,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-shiori/shiori/internal/dependencies" "github.com/go-shiori/shiori/internal/dependencies"
"github.com/go-shiori/shiori/internal/http/context"
"github.com/go-shiori/shiori/internal/http/middleware"
"github.com/go-shiori/shiori/internal/http/response" "github.com/go-shiori/shiori/internal/http/response"
"github.com/go-shiori/shiori/internal/model" "github.com/go-shiori/shiori/internal/model"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -16,6 +18,7 @@ type TagsAPIRoutes struct {
} }
func (r *TagsAPIRoutes) Setup(g *gin.RouterGroup) model.Routes { func (r *TagsAPIRoutes) Setup(g *gin.RouterGroup) model.Routes {
g.Use(middleware.AuthenticationRequired())
g.GET("/", r.listHandler) g.GET("/", r.listHandler)
g.POST("/", r.createHandler) g.POST("/", r.createHandler)
return r return r
@ -47,6 +50,12 @@ func (r *TagsAPIRoutes) listHandler(c *gin.Context) {
// @Failure 403 {object} nil "Token not provided/invalid" // @Failure 403 {object} nil "Token not provided/invalid"
// @Router /api/v1/tags [post] // @Router /api/v1/tags [post]
func (r *TagsAPIRoutes) createHandler(c *gin.Context) { func (r *TagsAPIRoutes) createHandler(c *gin.Context) {
ctx := context.NewContextFromGin(c)
if !ctx.GetAccount().Owner {
response.SendError(c, http.StatusForbidden, nil)
return
}
var tag model.Tag var tag model.Tag
if err := c.BindJSON(&tag); err != nil { if err := c.BindJSON(&tag); err != nil {
response.SendError(c, http.StatusBadRequest, nil) response.SendError(c, http.StatusBadRequest, nil)

View file

@ -4,8 +4,11 @@ import (
"context" "context"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-shiori/shiori/internal/http/middleware"
"github.com/go-shiori/shiori/internal/model"
"github.com/go-shiori/shiori/internal/testutil" "github.com/go-shiori/shiori/internal/testutil"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -15,34 +18,81 @@ func TestTagList(t *testing.T) {
logger := logrus.New() logger := logrus.New()
ctx := context.TODO() ctx := context.TODO()
t.Run("empty tag list", func(t *testing.T) { g := gin.New()
g := gin.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger) _, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
router := NewTagsPIRoutes(logger, deps) g.Use(middleware.AuthMiddleware(deps))
router.Setup(g.Group("/"))
account := model.Account{
Username: "test",
Password: "test",
Owner: true,
}
require.NoError(t, deps.Database.SaveAccount(ctx, account))
token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
require.NoError(t, err)
bookmark := testutil.GetValidBookmark()
bookmark.Tags = []model.Tag{
{Name: "test"},
}
_, err = deps.Database.SaveBookmarks(ctx, true, *bookmark)
require.NoError(t, err)
router := NewTagsPIRoutes(logger, deps)
router.Setup(g.Group("/"))
t.Run("require authentication", func(t *testing.T) {
w := testutil.PerformRequest(g, "GET", "/") w := testutil.PerformRequest(g, "GET", "/")
require.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusUnauthorized, w.Code)
response, err := testutil.NewTestResponseFromReader(w.Body) response, err := testutil.NewTestResponseFromReader(w.Body)
require.NoError(t, err) require.NoError(t, err)
response.AssertMessageIsEmptyList(t) response.AssertNotOk(t)
}) })
t.Run("return tags", func(t *testing.T) { t.Run("return tags", func(t *testing.T) {
ctx := context.TODO() w := testutil.PerformRequest(g, "GET", "/", testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token))
g := gin.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
router := NewTagsPIRoutes(logger, deps)
router.Setup(g.Group("/"))
w := testutil.PerformRequest(g, "GET", "/")
require.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
response, err := testutil.NewTestResponseFromReader(w.Body) response, err := testutil.NewTestResponseFromReader(w.Body)
require.NoError(t, err) require.NoError(t, err)
response.AssertMessageIsEmptyList(t) response.AssertOk(t)
response.AssertMessageIsListLength(t, 1)
})
}
func TestTagCreate(t *testing.T) {
logger := logrus.New()
ctx := context.TODO()
g := gin.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
g.Use(middleware.AuthMiddleware(deps))
account := model.Account{
Username: "test",
Password: "test",
Owner: true,
}
require.NoError(t, deps.Database.SaveAccount(ctx, account))
// token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
// require.NoError(t, err)
router := NewTagsPIRoutes(logger, deps)
router.Setup(g.Group("/"))
t.Run("require authentication", func(t *testing.T) {
w := testutil.PerformRequest(g, "POST", "/")
require.Equal(t, http.StatusUnauthorized, w.Code)
})
t.Run("create tag", func(t *testing.T) {
// TODO: Implement this test
// Tags require a bookmark to be created, so we need to create a bookmark first
// but I'm not sure if we should enforce this.
}) })
} }

View file

@ -11,42 +11,50 @@ import (
) )
type testResponse struct { type testResponse struct {
Response *response.Response Response response.Response
} }
func (r *testResponse) AssertMessageIsEmptyList(t *testing.T) { func (r *testResponse) AssertMessageIsEmptyList(t *testing.T) {
require.Equal(t, []interface{}{}, r.Response.Message) require.Equal(t, []interface{}{}, r.Response.GetMessage())
} }
func (r *testResponse) AssertNilMessage(t *testing.T) { func (r *testResponse) AssertNilMessage(t *testing.T) {
require.Equal(t, nil, r.Response.Message) require.Equal(t, nil, r.Response.GetMessage())
} }
func (r testResponse) AssertMessageEquals(t *testing.T, expected interface{}) { func (r testResponse) AssertMessageEquals(t *testing.T, expected interface{}) {
require.Equal(t, expected, r.Response.Message) require.Equal(t, expected, r.Response.GetMessage())
}
func (r *testResponse) AssertMessageIsListLength(t *testing.T, length int) {
require.Len(t, r.Response.GetMessage(), length)
} }
func (r *testResponse) AssertOk(t *testing.T) { func (r *testResponse) AssertOk(t *testing.T) {
require.True(t, r.Response.Ok) require.False(t, r.Response.IsError())
} }
func (r *testResponse) AssertNotOk(t *testing.T) { func (r *testResponse) AssertNotOk(t *testing.T) {
require.False(t, r.Response.Ok) require.True(t, r.Response.IsError())
}
func (r *testResponse) Assert(t *testing.T, fn func(t *testing.T, r *testResponse)) {
fn(t, r)
} }
func NewTestResponseFromBytes(b []byte) (*testResponse, error) { func NewTestResponseFromBytes(b []byte) (*testResponse, error) {
r := testResponse{} tr := testResponse{}
if err := json.Unmarshal(b, &r.Response); err != nil { if err := json.Unmarshal(b, &tr.Response); err != nil {
return nil, errors.Wrap(err, "error parsing response") return nil, errors.Wrap(err, "error parsing response")
} }
return &r, nil return &tr, nil
} }
func NewTestResponseFromReader(r io.Reader) (*testResponse, error) { func NewTestResponseFromReader(r io.Reader) (*testResponse, error) {
response := testResponse{} tr := testResponse{}
decoder := json.NewDecoder(r) decoder := json.NewDecoder(r)
if err := decoder.Decode(&response.Response); err != nil { if err := decoder.Decode(&tr.Response); err != nil {
return nil, errors.Wrap(err, "error parsing response") return nil, errors.Wrap(err, "error parsing response")
} }
return &response, nil return &tr, nil
} }