From 83d3af3527458652dbbe1d1a25d26249119235fb Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Mon, 2 Sep 2024 21:55:43 +0530 Subject: [PATCH] Fix list auth by adding an explicit 'getAll' flag to query. --- cmd/lists.go | 13 +++++++++---- cmd/public.go | 4 ++-- internal/core/lists.go | 10 +++++----- queries.sql | 6 +++--- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/cmd/lists.go b/cmd/lists.go index 051b23b4..0570215b 100644 --- a/cmd/lists.go +++ b/cmd/lists.go @@ -28,14 +28,19 @@ func handleGetLists(c echo.Context) error { out models.PageResults ) - var permittedIDs []int - if _, ok := user.PermissionsMap["lists:get_all"]; !ok { + var ( + permittedIDs []int + getAll = false + ) + if _, ok := user.PermissionsMap["lists:get_all"]; ok { + getAll = true + } else { permittedIDs = user.GetListIDs } // Minimal query simply returns the list of all lists without JOIN subscriber counts. This is fast. if minimal { - res, err := app.core.GetLists("", permittedIDs) + res, err := app.core.GetLists("", getAll, permittedIDs) if err != nil { return err } @@ -53,7 +58,7 @@ func handleGetLists(c echo.Context) error { } // Full list query. - res, total, err := app.core.QueryLists(query, typ, optin, tags, orderBy, order, permittedIDs, pg.Offset, pg.Limit) + res, total, err := app.core.QueryLists(query, typ, optin, tags, orderBy, order, getAll, permittedIDs, pg.Offset, pg.Limit) if err != nil { return err } diff --git a/cmd/public.go b/cmd/public.go index 79654093..3a408441 100644 --- a/cmd/public.go +++ b/cmd/public.go @@ -115,7 +115,7 @@ func handleGetPublicLists(c echo.Context) error { ) // Get all public lists. - lists, err := app.core.GetLists(models.ListTypePublic, nil) + lists, err := app.core.GetLists(models.ListTypePublic, true, nil) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, app.i18n.T("public.errorFetchingLists")) } @@ -418,7 +418,7 @@ func handleSubscriptionFormPage(c echo.Context) error { } // Get all public lists. - lists, err := app.core.GetLists(models.ListTypePublic, nil) + lists, err := app.core.GetLists(models.ListTypePublic, true, nil) if err != nil { return c.Render(http.StatusInternalServerError, tplMessage, makeMsgTpl(app.i18n.T("public.errorTitle"), "", app.i18n.Ts("public.errorFetchingLists"))) diff --git a/internal/core/lists.go b/internal/core/lists.go index e2fabbf1..0b8c7285 100644 --- a/internal/core/lists.go +++ b/internal/core/lists.go @@ -10,10 +10,10 @@ import ( ) // GetLists gets all lists optionally filtered by type. -func (c *Core) GetLists(typ string, permittedIDs []int) ([]models.List, error) { +func (c *Core) GetLists(typ string, getAll bool, permittedIDs []int) ([]models.List, error) { out := []models.List{} - if err := c.q.GetLists.Select(&out, typ, "id", pq.Array(permittedIDs)); err != nil { + if err := c.q.GetLists.Select(&out, typ, "id", getAll, pq.Array(permittedIDs)); err != nil { c.log.Printf("error fetching lists: %v", err) return nil, echo.NewHTTPError(http.StatusInternalServerError, c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.lists}", "error", pqErrMsg(err))) @@ -36,7 +36,7 @@ func (c *Core) GetLists(typ string, permittedIDs []int) ([]models.List, error) { // QueryLists gets multiple lists based on multiple query params. Along with the paginated and sliced // results, the total number of lists in the DB is returned. -func (c *Core) QueryLists(searchStr, typ, optin string, tags []string, orderBy, order string, permittedIDs []int, offset, limit int) ([]models.List, int, error) { +func (c *Core) QueryLists(searchStr, typ, optin string, tags []string, orderBy, order string, getAll bool, permittedIDs []int, offset, limit int) ([]models.List, int, error) { _ = c.refreshCache(matListSubStats, false) if tags == nil { @@ -47,7 +47,7 @@ func (c *Core) QueryLists(searchStr, typ, optin string, tags []string, orderBy, out = []models.List{} queryStr, stmt = makeSearchQuery(searchStr, orderBy, order, c.q.QueryLists, listQuerySortFields) ) - if err := c.db.Select(&out, stmt, 0, "", queryStr, typ, optin, pq.StringArray(tags), pq.Array(permittedIDs), offset, limit); err != nil { + if err := c.db.Select(&out, stmt, 0, "", queryStr, typ, optin, pq.StringArray(tags), getAll, pq.Array(permittedIDs), offset, limit); err != nil { c.log.Printf("error fetching lists: %v", err) return nil, 0, echo.NewHTTPError(http.StatusInternalServerError, c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.lists}", "error", pqErrMsg(err))) @@ -82,7 +82,7 @@ func (c *Core) GetList(id int, uuid string) (models.List, error) { var res []models.List queryStr, stmt := makeSearchQuery("", "", "", c.q.QueryLists, nil) - if err := c.db.Select(&res, stmt, id, uu, queryStr, "", "", pq.StringArray{}, nil, 0, 1); err != nil { + if err := c.db.Select(&res, stmt, id, uu, queryStr, "", "", pq.StringArray{}, true, nil, 0, 1); err != nil { c.log.Printf("error fetching lists: %v", err) return models.List{}, echo.NewHTTPError(http.StatusInternalServerError, c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.lists}", "error", pqErrMsg(err))) diff --git a/queries.sql b/queries.sql index af338e2a..eae16fe2 100644 --- a/queries.sql +++ b/queries.sql @@ -423,7 +423,7 @@ UPDATE subscriber_lists SET status='unsubscribed', updated_at=NOW() SELECT * FROM lists WHERE (CASE WHEN $1 = '' THEN 1=1 ELSE type=$1::list_type END) AND CASE -- Optional list IDs based on user permission. - WHEN $3::INT[] IS NULL THEN TRUE ELSE id = ANY($3) + WHEN $3 = TRUE THEN TRUE ELSE id = ANY($4::INT[]) END ORDER BY CASE WHEN $2 = 'id' THEN id END, CASE WHEN $2 = 'name' THEN name END; @@ -441,9 +441,9 @@ WITH ls AS ( AND (CARDINALITY($6::VARCHAR(100)[]) = 0 OR $6 <@ tags) AND CASE -- Optional list IDs based on user permission. - WHEN $7::INT[] IS NULL THEN TRUE ELSE id = ANY($7) + WHEN $7 = TRUE THEN TRUE ELSE id = ANY($8::INT[]) END - OFFSET $8 LIMIT (CASE WHEN $9 < 1 THEN NULL ELSE $9 END) + OFFSET $9 LIMIT (CASE WHEN $10 < 1 THEN NULL ELSE $10 END) ), statuses AS ( SELECT