Merge branch 'fix-sql-search'

This commit is contained in:
Kailash Nadh 2025-04-18 15:33:41 +05:30
commit b44ea0c336
7 changed files with 309 additions and 116 deletions

View file

@ -26,6 +26,7 @@ const (
// subQueryReq is a "catch all" struct for reading various
// subscriber related requests.
type subQueryReq struct {
Search string `json:"search"`
Query string `json:"query"`
ListIDs []int `json:"list_ids"`
TargetListIDs []int `json:"target_list_ids"`
@ -84,21 +85,32 @@ func (a *App) QuerySubscribers(c echo.Context) error {
return err
}
// Does the user have the subscribers:sql_query permission?
query := formatSQLExp(c.FormValue("query"))
if query != "" {
if !user.HasPerm(auth.PermSubscribersSqlQuery) {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", auth.PermSubscribersSqlQuery))
}
}
var (
// The "WHERE ?" bit.
query = sanitizeSQLExp(c.FormValue("query"))
searchStr = strings.TrimSpace(c.FormValue("search"))
subStatus = c.FormValue("subscription_status")
orderBy = c.FormValue("order_by")
order = c.FormValue("order")
orderBy = c.FormValue("order_by")
pg = a.pg.NewFromURL(c.Request().URL.Query())
)
res, total, err := a.core.QuerySubscribers(query, listIDs, subStatus, order, orderBy, pg.Offset, pg.Limit)
// Query subscribers from the DB.
res, total, err := a.core.QuerySubscribers(searchStr, query, listIDs, subStatus, order, orderBy, pg.Offset, pg.Limit)
if err != nil {
return err
}
out := models.PageResults{
Query: query,
Search: searchStr,
Results: res,
Total: total,
Page: pg.Page,
@ -128,9 +140,20 @@ func (a *App) ExportSubscribers(c echo.Context) error {
// Filter by subscription status
subStatus := c.QueryParam("subscription_status")
// Does the user have the subscribers:sql_query permission?
var (
searchStr = strings.TrimSpace(c.FormValue("search"))
query = formatSQLExp(c.FormValue("query"))
)
if query != "" {
if !user.HasPerm(auth.PermSubscribersSqlQuery) {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", auth.PermSubscribersSqlQuery))
}
}
// Get the batched export iterator.
query := sanitizeSQLExp(c.FormValue("query"))
exp, err := a.core.ExportSubscribers(query, subIDs, listIDs, subStatus, a.cfg.DBBatchSize)
exp, err := a.core.ExportSubscribers(searchStr, query, subIDs, listIDs, subStatus, a.cfg.DBBatchSize)
if err != nil {
return err
}
@ -325,7 +348,13 @@ func (a *App) ManageSubscriberLists(c echo.Context) error {
}
// Filter lists against the current user's permitted lists.
listIDs := user.FilterListsByPerm(auth.PermTypeManage, req.TargetListIDs)
listIDs := user.FilterListsByPerm(auth.PermTypeGet|auth.PermTypeManage, req.TargetListIDs)
// User doesn't have the required list permissions.
if len(listIDs) == 0 {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", "lists"))
}
// Run the action in the DB.
var err error
@ -382,20 +411,34 @@ func (a *App) DeleteSubscribers(c echo.Context) error {
// DeleteSubscribersByQuery bulk deletes based on an
// arbitrary SQL expression.
func (a *App) DeleteSubscribersByQuery(c echo.Context) error {
// Get the authenticated user.
user := auth.GetUser(c)
var req subQueryReq
if err := c.Bind(&req); err != nil {
return err
}
req.Search = strings.TrimSpace(req.Search)
req.Query = formatSQLExp(req.Query)
if req.All {
// If the "all" flag is set, ignore any subquery that may be present.
req.Search = ""
req.Query = ""
} else if req.Query == "" {
} else if req.Search == "" && req.Query == "" {
return echo.NewHTTPError(http.StatusBadRequest, a.i18n.Ts("globals.messages.invalidFields", "name", "query"))
}
// Does the user have the subscribers:sql_query permission?
if req.Query != "" {
if !user.HasPerm(auth.PermSubscribersSqlQuery) {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", auth.PermSubscribersSqlQuery))
}
}
// Delete the subscribers from the DB.
if err := a.core.DeleteSubscribersByQuery(req.Query, req.ListIDs, req.SubscriptionStatus); err != nil {
if err := a.core.DeleteSubscribersByQuery(req.Search, req.Query, req.ListIDs, req.SubscriptionStatus); err != nil {
return err
}
@ -405,17 +448,30 @@ func (a *App) DeleteSubscribersByQuery(c echo.Context) error {
// BlocklistSubscribersByQuery bulk blocklists subscribers
// based on an arbitrary SQL expression.
func (a *App) BlocklistSubscribersByQuery(c echo.Context) error {
// Get the authenticated user.
user := auth.GetUser(c)
var req subQueryReq
if err := c.Bind(&req); err != nil {
return err
}
if req.Query == "" {
req.Search = strings.TrimSpace(req.Search)
req.Query = formatSQLExp(req.Query)
if req.Search == "" && req.Query == "" {
return echo.NewHTTPError(http.StatusBadRequest, a.i18n.Ts("globals.messages.invalidFields", "name", "query"))
}
// Does the user have the subscribers:sql_query permission?
if req.Query != "" {
if !user.HasPerm(auth.PermSubscribersSqlQuery) {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", auth.PermSubscribersSqlQuery))
}
}
// Update the subscribers in the DB.
if err := a.core.BlocklistSubscribersByQuery(req.Query, req.ListIDs, req.SubscriptionStatus); err != nil {
if err := a.core.BlocklistSubscribersByQuery(req.Search, req.Query, req.ListIDs, req.SubscriptionStatus); err != nil {
return err
}
@ -437,19 +493,33 @@ func (a *App) ManageSubscriberListsByQuery(c echo.Context) error {
a.i18n.T("subscribers.errorNoListsGiven"))
}
req.Search = strings.TrimSpace(req.Search)
req.Query = formatSQLExp(req.Query)
if req.Search == "" && req.Query == "" {
return echo.NewHTTPError(http.StatusBadRequest, a.i18n.Ts("globals.messages.invalidFields", "name", "query"))
}
// Does the user have the subscribers:sql_query permission?
if req.Query != "" {
if !user.HasPerm(auth.PermSubscribersSqlQuery) {
return echo.NewHTTPError(http.StatusForbidden,
a.i18n.Ts("globals.messages.permissionDenied", "name", auth.PermSubscribersSqlQuery))
}
}
// Filter lists against the current user's permitted lists.
sourceListIDs := user.FilterListsByPerm(auth.PermTypeManage, req.ListIDs)
targetListIDs := user.FilterListsByPerm(auth.PermTypeManage, req.TargetListIDs)
sourceListIDs := user.FilterListsByPerm(auth.PermTypeGet|auth.PermTypeManage, req.ListIDs)
targetListIDs := user.FilterListsByPerm(auth.PermTypeGet|auth.PermTypeManage, req.TargetListIDs)
// Run the action in the DB.
var err error
switch req.Action {
case "add":
err = a.core.AddSubscriptionsByQuery(req.Query, sourceListIDs, targetListIDs, req.Status, req.SubscriptionStatus)
err = a.core.AddSubscriptionsByQuery(req.Search, req.Query, sourceListIDs, targetListIDs, req.Status, req.SubscriptionStatus)
case "remove":
err = a.core.DeleteSubscriptionsByQuery(req.Query, sourceListIDs, targetListIDs, req.SubscriptionStatus)
err = a.core.DeleteSubscriptionsByQuery(req.Search, req.Query, sourceListIDs, targetListIDs, req.SubscriptionStatus)
case "unsubscribe":
err = a.core.UnsubscribeListsByQuery(req.Query, sourceListIDs, targetListIDs, req.SubscriptionStatus)
err = a.core.UnsubscribeListsByQuery(req.Search, req.Query, sourceListIDs, targetListIDs, req.SubscriptionStatus)
default:
return echo.NewHTTPError(http.StatusBadRequest, a.i18n.T("subscribers.invalidAction"))
}
@ -583,15 +653,15 @@ func (a *App) filterListQueryByPerm(param string, qp url.Values, user auth.User)
return listIDs, nil
}
// sanitizeSQLExp does basic sanitisation on arbitrary
// formatSQLExp does basic sanitisation on arbitrary
// SQL query expressions coming from the frontend.
func sanitizeSQLExp(q string) string {
func formatSQLExp(q string) string {
q = strings.TrimSpace(q)
if len(q) == 0 {
return ""
}
// Remove semicolon suffix.
q = strings.TrimSpace(q)
if q[len(q)-1] == ';' {
q = q[:len(q)-1]
}

View file

@ -223,6 +223,7 @@ export default Vue.extend({
queryParams: {
// Search query expression.
queryExp: '',
search: '',
// ID of the list the current subscriber view is filtered by.
listID: null,
@ -242,6 +243,7 @@ export default Vue.extend({
toggleAdvancedSearch() {
this.isSearchAdvanced = !this.isSearchAdvanced;
this.queryParams.search = '';
// Toggling to simple search.
if (!this.isSearchAdvanced) {
@ -253,6 +255,16 @@ export default Vue.extend({
return;
}
// Toggling to advanced search.
const q = this.queryInput.replace(/'/, "''").trim();
if (q) {
if (this.$utils.validateEmail(q)) {
this.queryParams.queryExp = `email = '${q.toLowerCase()}'`;
} else {
this.queryParams.queryExp = `(name ~* '${q}' OR email ~* '${q.toLowerCase()}')`;
}
}
// Toggling to advanced search.
this.$nextTick(() => {
this.$refs.queryExp.focus();
@ -307,13 +319,9 @@ export default Vue.extend({
// in this.queryExp.
onSimpleQueryInput(v) {
const q = v.replace(/'/, "''").trim();
this.queryParams.queryExp = '';
this.queryParams.page = 1;
if (this.$utils.validateEmail(q)) {
this.queryParams.queryExp = `email = '${q.toLowerCase()}'`;
} else {
this.queryParams.queryExp = `(name ~* '${q}' OR email ~* '${q.toLowerCase()}')`;
}
this.queryParams.search = q.toLowerCase();
},
// Ctrl + Enter on the advanced query searches.
@ -331,15 +339,24 @@ export default Vue.extend({
querySubscribers(params) {
this.queryParams = { ...this.queryParams, ...params };
const qp = {
list_id: this.queryParams.listID,
search: this.queryParams.search,
query: this.queryParams.queryExp,
page: this.queryParams.page,
subscription_status: this.queryParams.subStatus,
order_by: this.queryParams.orderBy,
order: this.queryParams.order,
};
if (this.queryParams.queryExp) {
delete qp.search;
} else {
delete qp.queryExp;
}
this.$nextTick(() => {
this.$api.getSubscribers({
list_id: this.queryParams.listID,
query: this.queryParams.queryExp,
page: this.queryParams.page,
subscription_status: this.queryParams.subStatus,
order_by: this.queryParams.orderBy,
order: this.queryParams.order,
}).then(() => {
this.$api.getSubscribers(qp).then(() => {
this.bulk.checked = [];
});
});
@ -371,6 +388,7 @@ export default Vue.extend({
// 'All' is selected, blocklist by query.
fn = () => {
this.$api.blocklistSubscribersByQuery({
search: this.queryParams.search,
query: this.queryParams.queryExp,
list_ids: this.queryParams.listID ? [this.queryParams.listID] : null,
subscription_status: this.queryParams.subStatus,
@ -387,7 +405,12 @@ export default Vue.extend({
this.$utils.confirm(this.$t('subscribers.confirmExport', { num }), () => {
const q = new URLSearchParams();
q.append('query', this.queryParams.queryExp);
if (this.queryParams.search) {
q.append('search', this.queryParams.search);
} else if (this.queryParams.queryExp) {
q.append('query', this.queryParams.queryExp);
}
if (this.queryParams.listID) {
q.append('list_id', this.queryParams.listID);
@ -426,6 +449,7 @@ export default Vue.extend({
// If the query expression is empty, explicitly pass `all=true`
// so that the backend deletes all records in the DB with an empty query string.
all: this.queryParams.queryExp.trim() === '',
search: this.queryParams.search,
query: this.queryParams.queryExp,
list_ids: this.queryParams.listID ? [this.queryParams.listID] : null,
subscription_status: this.queryParams.subStatus,
@ -447,6 +471,7 @@ export default Vue.extend({
const data = {
action,
query: this.fullQueryExp,
search: this.queryParams.search,
list_ids: this.queryParams.listID ? [this.queryParams.listID] : null,
target_list_ids: lists.map((l) => l.id),
};

View file

@ -9,12 +9,27 @@ import (
"strings"
"github.com/gofrs/uuid/v5"
"github.com/jmoiron/sqlx"
"github.com/knadh/listmonk/internal/auth"
"github.com/knadh/listmonk/models"
"github.com/labstack/echo/v4"
"github.com/lib/pq"
)
var (
allowedSubQueryTables = map[string]struct{}{
"subscribers": {},
"lists": {},
"subscribers_lists": {},
"campaigns": {},
"campaign_lists": {},
"campaign_views": {},
"links": {},
"link_clicks": {},
"bounces": {},
}
)
// GetSubscriber fetches a subscriber by one of the given params.
func (c *Core) GetSubscriber(id int, uuid, email string) (models.Subscriber, error) {
var uu any
@ -88,13 +103,7 @@ func (c *Core) GetSubscribersByEmail(emails []string) (models.Subscribers, error
}
// QuerySubscribers queries and returns paginated subscrribers based on the given params including the total count.
func (c *Core) QuerySubscribers(query string, listIDs []int, subStatus string, order, orderBy string, offset, limit int) (models.Subscribers, int, error) {
// There's an arbitrary query condition.
cond := ""
if query != "" {
cond = " AND " + query
}
func (c *Core) QuerySubscribers(searchStr, queryExp string, listIDs []int, subStatus string, order, orderBy string, offset, limit int) (models.Subscribers, int, error) {
// Sort params.
if !strSliceContains(orderBy, subQuerySortFields) {
orderBy = "subscribers.id"
@ -108,10 +117,28 @@ func (c *Core) QuerySubscribers(query string, listIDs []int, subStatus string, o
listIDs = []int{}
}
// There's an arbitrary query condition.
cond := "TRUE"
if queryExp != "" {
cond = queryExp
}
// stmt is the raw SQL query.
stmt := strings.ReplaceAll(c.q.QuerySubscribers, "%query%", cond)
stmt = strings.ReplaceAll(stmt, "%order%", orderBy+" "+order)
// Validate the tables used in the query.
if err := validateQueryTables(c.db, stmt, allowedSubQueryTables); err != nil {
c.log.Printf("error validating query tables: %v", err)
return nil, 0, echo.NewHTTPError(http.StatusBadRequest,
c.i18n.Ts("subscribers.errorPreparingQuery", "error", err.Error()))
}
// Create a readonly transaction that just does COUNT() to obtain the count of results
// and to ensure that the arbitrary query is indeed readonly.
total, err := c.getSubscriberCount(cond, subStatus, listIDs)
total, err := c.getSubscriberCount(searchStr, cond, subStatus, listIDs)
if err != nil {
c.log.Printf("error getting subscriber count: %v", err)
return nil, 0, err
}
@ -120,10 +147,6 @@ func (c *Core) QuerySubscribers(query string, listIDs []int, subStatus string, o
return models.Subscribers{}, 0, nil
}
// Run the query again and fetch the actual data. stmt is the raw SQL query.
stmt := strings.ReplaceAll(c.q.QuerySubscribers, "%query%", cond)
stmt = strings.ReplaceAll(stmt, "%order%", orderBy+" "+order)
tx, err := c.db.BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
c.log.Printf("error preparing subscriber query: %v", err)
@ -132,7 +155,7 @@ func (c *Core) QuerySubscribers(query string, listIDs []int, subStatus string, o
defer tx.Rollback()
var out models.Subscribers
if err := tx.Select(&out, stmt, pq.Array(listIDs), subStatus, offset, limit); err != nil {
if err := tx.Select(&out, stmt, pq.Array(listIDs), subStatus, searchStr, offset, limit); err != nil {
return nil, 0, echo.NewHTTPError(http.StatusInternalServerError,
c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.subscribers}", "error", pqErrMsg(err)))
}
@ -196,31 +219,7 @@ func (c *Core) GetSubscriberProfileForExport(id int, uuid string) (models.Subscr
// on the given criteria in an exportable form. The iterator function returned can be called
// repeatedly until there are nil subscribers. It's an iterator because exports can be extremely
// large and may have to be fetched in batches from the DB and streamed somewhere.
func (c *Core) ExportSubscribers(query string, subIDs, listIDs []int, subStatus string, batchSize int) (func() ([]models.SubscriberExport, error), error) {
// There's an arbitrary query condition.
cond := ""
if query != "" {
cond = " AND " + query
}
stmt := strings.ReplaceAll(c.q.QuerySubscribersForExport, "%query%", cond)
// Verify that the arbitrary SQL search expression is read only.
if cond != "" {
tx, err := c.db.Unsafe().BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
c.log.Printf("error preparing subscriber query: %v", err)
return nil, echo.NewHTTPError(http.StatusBadRequest,
c.i18n.Ts("subscribers.errorPreparingQuery", "error", pqErrMsg(err)))
}
defer tx.Rollback()
if _, err := tx.Query(stmt, nil, 0, nil, subStatus, 1); err != nil {
return nil, echo.NewHTTPError(http.StatusBadRequest,
c.i18n.Ts("subscribers.errorPreparingQuery", "error", pqErrMsg(err)))
}
}
func (c *Core) ExportSubscribers(searchStr, query string, subIDs, listIDs []int, subStatus string, batchSize int) (func() ([]models.SubscriberExport, error), error) {
if subIDs == nil {
subIDs = []int{}
}
@ -228,6 +227,21 @@ func (c *Core) ExportSubscribers(query string, subIDs, listIDs []int, subStatus
listIDs = []int{}
}
// There's an arbitrary query condition.
cond := "TRUE"
if query != "" {
cond = query
}
stmt := strings.ReplaceAll(c.q.QuerySubscribersForExport, "%query%", cond)
// Create a readonly transaction that just does COUNT() to obtain the count of results
// and to ensure that the arbitrary query is indeed readonly.
if _, err := c.getSubscriberCount(searchStr, cond, subStatus, listIDs); err != nil {
c.log.Printf("error getting subscriber count: %v", err)
return nil, err
}
// Prepare the actual query statement.
tx, err := c.db.Preparex(stmt)
if err != nil {
@ -239,7 +253,7 @@ func (c *Core) ExportSubscribers(query string, subIDs, listIDs []int, subStatus
id := 0
return func() ([]models.SubscriberExport, error) {
var out []models.SubscriberExport
if err := tx.Select(&out, pq.Array(listIDs), id, pq.Array(subIDs), subStatus, batchSize); err != nil {
if err := tx.Select(&out, pq.Array(listIDs), id, pq.Array(subIDs), subStatus, searchStr, batchSize); err != nil {
c.log.Printf("error exporting subscribers by query: %v", err)
return nil, echo.NewHTTPError(http.StatusInternalServerError,
c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.subscribers}", "error", pqErrMsg(err)))
@ -414,8 +428,8 @@ func (c *Core) BlocklistSubscribers(subIDs []int) error {
}
// BlocklistSubscribersByQuery blocklists the given list of subscribers.
func (c *Core) BlocklistSubscribersByQuery(query string, listIDs []int, subStatus string) error {
if err := c.q.ExecSubQueryTpl(sanitizeSQLExp(query), c.q.BlocklistSubscribersByQuery, listIDs, c.db, subStatus); err != nil {
func (c *Core) BlocklistSubscribersByQuery(searchStr, queryExp string, listIDs []int, subStatus string) error {
if err := c.q.ExecSubQueryTpl(searchStr, sanitizeSQLExp(queryExp), c.q.BlocklistSubscribersByQuery, listIDs, c.db, subStatus); err != nil {
c.log.Printf("error blocklisting subscribers: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError,
c.i18n.Ts("subscribers.errorBlocklisting", "error", pqErrMsg(err)))
@ -443,8 +457,8 @@ func (c *Core) DeleteSubscribers(subIDs []int, subUUIDs []string) error {
}
// DeleteSubscribersByQuery deletes subscribers by a given arbitrary query expression.
func (c *Core) DeleteSubscribersByQuery(query string, listIDs []int, subStatus string) error {
err := c.q.ExecSubQueryTpl(sanitizeSQLExp(query), c.q.DeleteSubscribersByQuery, listIDs, c.db, subStatus)
func (c *Core) DeleteSubscribersByQuery(searchStr, queryExp string, listIDs []int, subStatus string) error {
err := c.q.ExecSubQueryTpl(searchStr, sanitizeSQLExp(queryExp), c.q.DeleteSubscribersByQuery, listIDs, c.db, subStatus)
if err != nil {
c.log.Printf("error deleting subscribers: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError,
@ -522,9 +536,9 @@ func (c *Core) DeleteBlocklistedSubscribers() (int, error) {
return int(n), nil
}
func (c *Core) getSubscriberCount(cond, subStatus string, listIDs []int) (int, error) {
func (c *Core) getSubscriberCount(searchStr, queryExp, subStatus string, listIDs []int) (int, error) {
// If there's no condition, it's a "get all" call which can probably be optionally pulled from cache.
if cond == "" {
if queryExp == "" {
_ = c.refreshCache(matListSubStats, false)
total := 0
@ -538,7 +552,7 @@ func (c *Core) getSubscriberCount(cond, subStatus string, listIDs []int) (int, e
// Create a readonly transaction that just does COUNT() to obtain the count of results
// and to ensure that the arbitrary query is indeed readonly.
stmt := fmt.Sprintf(c.q.QuerySubscribersCount, cond)
stmt := strings.ReplaceAll(c.q.QuerySubscribersCount, "%query%", queryExp)
tx, err := c.db.BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
c.log.Printf("error preparing subscriber query: %v", err)
@ -548,10 +562,80 @@ func (c *Core) getSubscriberCount(cond, subStatus string, listIDs []int) (int, e
// Execute the readonly query and get the count of results.
total := 0
if err := tx.Get(&total, stmt, pq.Array(listIDs), subStatus); err != nil {
if err := tx.Get(&total, stmt, pq.Array(listIDs), subStatus, searchStr); err != nil {
return 0, echo.NewHTTPError(http.StatusInternalServerError,
c.i18n.Ts("globals.messages.errorFetching", "name", "{globals.terms.subscribers}", "error", pqErrMsg(err)))
}
return total, nil
}
// validateQueryTables checks if the query accesses only allowed tables.
func validateQueryTables(db *sqlx.DB, query string, allowedTables map[string]struct{}) error {
// Get the EXPLAIN (FORMAT JSON) output.
tx, err := db.BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
return err
}
defer tx.Rollback()
var plan string
if err = tx.QueryRow("EXPLAIN (FORMAT JSON) "+query, nil, models.SubscriberStatusEnabled, "", 0, 10).Scan(&plan); err != nil {
return err
}
// Extract all relation names from the JSON plan.
tables, err := getTablesFromQueryPlan(plan)
if err != nil {
return fmt.Errorf("error getting tables from query: %v", err)
}
// Validate against allowed tables.
for _, table := range tables {
if _, ok := allowedTables[table]; !ok {
return fmt.Errorf("table '%s' is not allowed", table)
}
}
return nil
}
// getTablesFromQueryPlan parses the EXPLAIN JSON to find all "Relation Name" entries.
func getTablesFromQueryPlan(explainJSON string) ([]string, error) {
var plans []map[string]any
if err := json.Unmarshal([]byte(explainJSON), &plans); err != nil {
return nil, err
}
// Collect table names in `tables` recursively.
tables := make(map[string]struct{})
for _, plan := range plans {
traverseQueryPlan(plan, tables)
}
result := make([]string, 0, len(tables))
for table := range tables {
result = append(result, table)
}
return result, nil
}
func traverseQueryPlan(node map[string]any, tables map[string]struct{}) {
if relName, ok := node["Relation Name"].(string); ok {
tables[relName] = struct{}{}
}
// Recursively check nested plans (e.g., subqueries, CTEs).
for _, v := range node {
switch v := v.(type) {
case map[string]any:
traverseQueryPlan(v, tables)
case []any:
for _, item := range v {
if m, ok := item.(map[string]any); ok {
traverseQueryPlan(m, tables)
}
}
}
}
}

View file

@ -35,12 +35,12 @@ func (c *Core) AddSubscriptions(subIDs, listIDs []int, status string) error {
// AddSubscriptionsByQuery adds list subscriptions to subscribers by a given arbitrary query expression.
// sourceListIDs is the list of list IDs to filter the subscriber query with.
func (c *Core) AddSubscriptionsByQuery(query string, sourceListIDs, targetListIDs []int, status string, subStatus string) error {
func (c *Core) AddSubscriptionsByQuery(searchStr, queryExp string, sourceListIDs, targetListIDs []int, status string, subStatus string) error {
if sourceListIDs == nil {
sourceListIDs = []int{}
}
err := c.q.ExecSubQueryTpl(sanitizeSQLExp(query), c.q.AddSubscribersToListsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs), status)
err := c.q.ExecSubQueryTpl(searchStr, queryExp, c.q.AddSubscribersToListsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs), status)
if err != nil {
c.log.Printf("error adding subscriptions by query: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError,
@ -64,12 +64,12 @@ func (c *Core) DeleteSubscriptions(subIDs, listIDs []int) error {
// DeleteSubscriptionsByQuery deletes list subscriptions from subscribers by a given arbitrary query expression.
// sourceListIDs is the list of list IDs to filter the subscriber query with.
func (c *Core) DeleteSubscriptionsByQuery(query string, sourceListIDs, targetListIDs []int, subStatus string) error {
func (c *Core) DeleteSubscriptionsByQuery(searchStr, queryExp string, sourceListIDs, targetListIDs []int, subStatus string) error {
if sourceListIDs == nil {
sourceListIDs = []int{}
}
err := c.q.ExecSubQueryTpl(sanitizeSQLExp(query), c.q.DeleteSubscriptionsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs))
err := c.q.ExecSubQueryTpl(searchStr, queryExp, c.q.DeleteSubscriptionsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs))
if err != nil {
c.log.Printf("error deleting subscriptions by query: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError,
@ -92,12 +92,12 @@ func (c *Core) UnsubscribeLists(subIDs, listIDs []int, listUUIDs []string) error
// UnsubscribeListsByQuery sets list subscriptions to 'unsubscribed' by a given arbitrary query expression.
// sourceListIDs is the list of list IDs to filter the subscriber query with.
func (c *Core) UnsubscribeListsByQuery(query string, sourceListIDs, targetListIDs []int, subStatus string) error {
func (c *Core) UnsubscribeListsByQuery(searchStr, queryExp string, sourceListIDs, targetListIDs []int, subStatus string) error {
if sourceListIDs == nil {
sourceListIDs = []int{}
}
err := c.q.ExecSubQueryTpl(sanitizeSQLExp(query), c.q.UnsubscribeSubscribersFromListsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs))
err := c.q.ExecSubQueryTpl(searchStr, queryExp, c.q.UnsubscribeSubscribersFromListsByQuery, sourceListIDs, c.db, subStatus, pq.Array(targetListIDs))
if err != nil {
c.log.Printf("error unsubscribing from lists by query: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError,

View file

@ -123,6 +123,7 @@ var regTplFuncs = []regTplFunc{
type PageResults struct {
Results any `json:"results"`
Search string `json:"search"`
Query string `json:"query"`
Total int `json:"total"`
PerPage int `json:"per_page"`

View file

@ -3,7 +3,7 @@ package models
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
@ -131,35 +131,39 @@ type Queries struct {
DeleteListPermission *sqlx.Stmt `query:"delete-list-permission"`
}
// CompileSubscriberQueryTpl takes an arbitrary WHERE expressions
// compileSubscriberQueryTpl takes an arbitrary WHERE expressions
// to filter subscribers from the subscribers table and prepares a query
// out of it using the raw `query-subscribers-template` query template.
// While doing this, a readonly transaction is created and the query is
// dry run on it to ensure that it is indeed readonly.
func (q *Queries) CompileSubscriberQueryTpl(exp string, db *sqlx.DB, subStatus string) (string, error) {
func (q *Queries) compileSubscriberQueryTpl(searchStr, queryExp string, db *sqlx.DB, subStatus string) (string, error) {
tx, err := db.BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true})
if err != nil {
return "", err
}
defer tx.Rollback()
// Perform the dry run.
if exp != "" {
exp = " AND " + exp
// There's an arbitrary query condition.
cond := "TRUE"
if queryExp != "" {
cond = queryExp
}
stmt := fmt.Sprintf(q.QuerySubscribersTpl, exp)
if _, err := tx.Exec(stmt, true, pq.Int64Array{}, subStatus); err != nil {
// Perform the dry run.
stmt := strings.ReplaceAll(q.QuerySubscribersTpl, "%query%", cond)
if _, err := tx.Exec(stmt, true, pq.Int64Array{}, subStatus, searchStr); err != nil {
return "", err
}
return stmt, nil
}
// compileSubscriberQueryTpl takes an arbitrary WHERE expressions and a subscriber
// query template that depends on the filter (eg: delete by query, blocklist by query etc.)
// combines and executes them.
func (q *Queries) ExecSubQueryTpl(exp, tpl string, listIDs []int, db *sqlx.DB, subStatus string, args ...any) error {
func (q *Queries) ExecSubQueryTpl(searchStr, queryExp, baseQueryTpl string, listIDs []int, db *sqlx.DB, subStatus string, args ...any) error {
// Perform a dry run.
filterExp, err := q.CompileSubscriberQueryTpl(exp, db, subStatus)
filterExp, err := q.compileSubscriberQueryTpl(searchStr, queryExp, db, subStatus)
if err != nil {
return err
}
@ -168,9 +172,14 @@ func (q *Queries) ExecSubQueryTpl(exp, tpl string, listIDs []int, db *sqlx.DB, s
listIDs = []int{}
}
// Insert the subscriber filter query into the target query.
stmt := strings.ReplaceAll(baseQueryTpl, "%query%", filterExp)
// First argument is the boolean indicating if the query is a dry run.
a := append([]any{false, pq.Array(listIDs), subStatus}, args...)
if _, err := db.Exec(fmt.Sprintf(tpl, filterExp), a...); err != nil {
a := append([]any{false, pq.Array(listIDs), subStatus, searchStr}, args...)
// Execute the query on the DB.
if _, err := db.Exec(stmt, a...); err != nil {
return err
}
return nil

View file

@ -309,7 +309,6 @@ SELECT (SELECT email FROM prof) as email,
-- there's a COUNT() OVER() that still returns the total result count
-- for pagination in the frontend, albeit being a field that'll repeat
-- with every resultant row.
-- %s = arbitrary expression, %s = order by field, %s = order direction
SELECT subscribers.* FROM subscribers
LEFT JOIN subscriber_lists
ON (
@ -319,8 +318,9 @@ SELECT subscribers.* FROM subscribers
AND ($2 = '' OR subscriber_lists.status = $2::subscription_status)
)
WHERE (CARDINALITY($1) = 0 OR subscriber_lists.list_id = ANY($1::INT[]))
%query%
ORDER BY %order% OFFSET $3 LIMIT (CASE WHEN $4 < 1 THEN NULL ELSE $4 END);
AND (CASE WHEN $3 != '' THEN name ~* $3 OR email ~* $3 ELSE TRUE END)
AND %query%
ORDER BY %order% OFFSET $4 LIMIT (CASE WHEN $5 < 1 THEN NULL ELSE $5 END);
-- name: query-subscribers-count
-- Replica of query-subscribers for obtaining the results count.
@ -332,7 +332,9 @@ SELECT COUNT(*) AS total FROM subscribers
AND subscriber_lists.subscriber_id = subscribers.id
AND ($2 = '' OR subscriber_lists.status = $2::subscription_status)
)
WHERE (CARDINALITY($1) = 0 OR subscriber_lists.list_id = ANY($1::INT[])) %s;
WHERE (CARDINALITY($1) = 0 OR subscriber_lists.list_id = ANY($1::INT[]))
AND (CASE WHEN $3 != '' THEN name ~* $3 OR email ~* $3 ELSE TRUE END)
AND %query%;
-- name: query-subscribers-count-all
-- Cached query for getting the "all" subscriber count without arbitrary conditions.
@ -344,7 +346,6 @@ SELECT COALESCE(SUM(subscriber_count), 0) AS total FROM mat_list_subscriber_stat
-- raw: true
-- Unprepared statement for issuring arbitrary WHERE conditions for
-- searching subscribers to do bulk CSV export.
-- %s = arbitrary expression
SELECT subscribers.id,
subscribers.uuid,
subscribers.email,
@ -363,8 +364,9 @@ SELECT subscribers.id,
)
WHERE subscriber_lists.list_id = ALL($1::INT[]) AND id > $2
AND (CASE WHEN CARDINALITY($3::INT[]) > 0 THEN id=ANY($3) ELSE true END)
%query%
ORDER BY subscribers.id ASC LIMIT (CASE WHEN $5 < 1 THEN NULL ELSE $5 END);
AND (CASE WHEN $5 != '' THEN name ~* $5 OR email ~* $5 ELSE TRUE END)
AND %query%
ORDER BY subscribers.id ASC LIMIT (CASE WHEN $6 < 1 THEN NULL ELSE $6 END);
-- name: query-subscribers-template
-- raw: true
@ -374,7 +376,7 @@ SELECT subscribers.id,
--
-- All queries that embed this query should expect
-- $1=true/false (dry-run or not) and $2=[]INT (option list IDs).
-- That is, their positional arguments should start from $3.
-- That is, their positional arguments should start from $4.
SELECT subscribers.id FROM subscribers
LEFT JOIN subscriber_lists
ON (
@ -383,17 +385,19 @@ ON (
AND subscriber_lists.subscriber_id = subscribers.id
AND ($3 = '' OR subscriber_lists.status = $3::subscription_status)
)
WHERE subscriber_lists.list_id = ALL($2::INT[]) %s
WHERE subscriber_lists.list_id = ALL($2::INT[])
AND (CASE WHEN $4 != '' THEN name ~* $4 OR email ~* $4 ELSE TRUE END)
AND %query%
LIMIT (CASE WHEN $1 THEN 1 END)
-- name: delete-subscribers-by-query
-- raw: true
WITH subs AS (%s)
WITH subs AS (%query%)
DELETE FROM subscribers WHERE id=ANY(SELECT id FROM subs);
-- name: blocklist-subscribers-by-query
-- raw: true
WITH subs AS (%s),
WITH subs AS (%query%),
b AS (
UPDATE subscribers SET status='blocklisted', updated_at=NOW()
WHERE id = ANY(SELECT id FROM subs)
@ -403,22 +407,22 @@ UPDATE subscriber_lists SET status='unsubscribed', updated_at=NOW()
-- name: add-subscribers-to-lists-by-query
-- raw: true
WITH subs AS (%s)
WITH subs AS (%query%)
INSERT INTO subscriber_lists (subscriber_id, list_id, status)
(SELECT a, b, (CASE WHEN $5 != '' THEN $5::subscription_status ELSE 'unconfirmed' END) FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($4::INT[]) b)
(SELECT a, b, (CASE WHEN $6 != '' THEN $6::subscription_status ELSE 'unconfirmed' END) FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($5::INT[]) b)
ON CONFLICT (subscriber_id, list_id) DO NOTHING;
-- name: delete-subscriptions-by-query
-- raw: true
WITH subs AS (%s)
WITH subs AS (%query%)
DELETE FROM subscriber_lists
WHERE (subscriber_id, list_id) = ANY(SELECT a, b FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($4::INT[]) b);
WHERE (subscriber_id, list_id) = ANY(SELECT a, b FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($5::INT[]) b);
-- name: unsubscribe-subscribers-from-lists-by-query
-- raw: true
WITH subs AS (%s)
WITH subs AS (%query%)
UPDATE subscriber_lists SET status='unsubscribed', updated_at=NOW()
WHERE (subscriber_id, list_id) = ANY(SELECT a, b FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($4::INT[]) b);
WHERE (subscriber_id, list_id) = ANY(SELECT a, b FROM UNNEST(ARRAY(SELECT id FROM subs)) a, UNNEST($5::INT[]) b);
-- lists