From 778a5eb184145327da958dbdda4f2481a19edbd9 Mon Sep 17 00:00:00 2001 From: Johnny Date: Mon, 23 Jun 2025 22:38:44 +0800 Subject: [PATCH] refactor: memo filter --- plugin/filter/common_converter.go | 448 ++++++++++++++++++++++++++ plugin/filter/dialect.go | 212 ++++++++++++ plugin/filter/filter.go | 1 + plugin/filter/templates.go | 146 +++++++++ store/db/mysql/memo_filter.go | 184 +++++++---- store/db/mysql/memo_filter_test.go | 20 ++ store/db/postgres/memo_filter.go | 276 ++++++++++------ store/db/postgres/memo_filter_test.go | 20 ++ store/db/sqlite/memo_filter.go | 172 +++++++--- store/db/sqlite/memo_filter_test.go | 26 ++ 10 files changed, 1304 insertions(+), 201 deletions(-) create mode 100644 plugin/filter/common_converter.go create mode 100644 plugin/filter/dialect.go create mode 100644 plugin/filter/templates.go diff --git a/plugin/filter/common_converter.go b/plugin/filter/common_converter.go new file mode 100644 index 000000000..7c25b87c5 --- /dev/null +++ b/plugin/filter/common_converter.go @@ -0,0 +1,448 @@ +package filter + +import ( + "fmt" + "slices" + "strings" + + "github.com/pkg/errors" + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// CommonSQLConverter handles the common CEL to SQL conversion logic +type CommonSQLConverter struct { + dialect SQLDialect + paramIndex int +} + +// NewCommonSQLConverter creates a new converter with the specified dialect +func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter { + return &CommonSQLConverter{ + dialect: dialect, + paramIndex: 1, + } +} + +// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect +func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error { + if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { + switch v.CallExpr.Function { + case "_||_", "_&&_": + return c.handleLogicalOperator(ctx, v.CallExpr) + case "!_": + return c.handleNotOperator(ctx, v.CallExpr) + case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": + return c.handleComparisonOperator(ctx, v.CallExpr) + case "@in": + return c.handleInOperator(ctx, v.CallExpr) + case "contains": + return c.handleContainsOperator(ctx, v.CallExpr) + } + } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { + return c.handleIdentifier(ctx, v.IdentExpr) + } + return nil +} + +func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { + if len(callExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", callExpr.Function) + } + + if _, err := ctx.Buffer.WriteString("("); err != nil { + return err + } + + if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil { + return err + } + + operator := "AND" + if callExpr.Function == "_||_" { + operator = "OR" + } + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { + return err + } + + if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil { + return err + } + + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + + return nil +} + +func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { + if len(callExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", callExpr.Function) + } + + if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { + return err + } + + if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil { + return err + } + + if _, err := ctx.Buffer.WriteString(")"); err != nil { + return err + } + + return nil +} + +func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { + if len(callExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", callExpr.Function) + } + + // Check if the left side is a function call like size(tags) + if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { + if leftCallExpr.CallExpr.Function == "size" { + return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr) + } + } + + identifier, err := GetIdentExprName(callExpr.Args[0]) + if err != nil { + return err + } + + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { + return errors.Errorf("invalid identifier for %s", callExpr.Function) + } + + value, err := GetExprValue(callExpr.Args[1]) + if err != nil { + return err + } + + operator := c.getComparisonOperator(callExpr.Function) + + switch identifier { + case "created_ts", "updated_ts": + return c.handleTimestampComparison(ctx, identifier, operator, value) + case "visibility", "content": + return c.handleStringComparison(ctx, identifier, operator, value) + case "creator_id": + return c.handleIntComparison(ctx, identifier, operator, value) + case "has_task_list": + return c.handleBooleanComparison(ctx, identifier, operator, value) + } + + return nil +} + +func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error { + if len(sizeCall.Args) != 1 { + return errors.New("size function requires exactly one argument") + } + + identifier, err := GetIdentExprName(sizeCall.Args[0]) + if err != nil { + return err + } + + if identifier != "tags" { + return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) + } + + value, err := GetExprValue(callExpr.Args[1]) + if err != nil { + return err + } + + valueInt, ok := value.(int64) + if !ok { + return errors.New("size comparison value must be an integer") + } + + operator := c.getComparisonOperator(callExpr.Function) + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", + c.dialect.GetJSONArrayLength("$.tags"), + operator, + c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, valueInt) + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { + if len(callExpr.Args) != 2 { + return errors.Errorf("invalid number of arguments for %s", callExpr.Function) + } + + // Check if this is "element in collection" syntax + if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil { + if identifier == "tags" { + return c.handleElementInTags(ctx, callExpr.Args[0]) + } + return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier) + } + + // Original logic for "identifier in [list]" syntax + identifier, err := GetIdentExprName(callExpr.Args[0]) + if err != nil { + return err + } + + if !slices.Contains([]string{"tag", "visibility"}, identifier) { + return errors.Errorf("invalid identifier for %s", callExpr.Function) + } + + values := []any{} + for _, element := range callExpr.Args[1].GetListExpr().Elements { + value, err := GetConstValue(element) + if err != nil { + return err + } + values = append(values, value) + } + + if identifier == "tag" { + return c.handleTagInList(ctx, values) + } else if identifier == "visibility" { + return c.handleVisibilityInList(ctx, values) + } + + return nil +} + +func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error { + element, err := GetConstValue(elementExpr) + if err != nil { + return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) + } + + // Use dialect-specific JSON contains logic + sqlExpr := c.dialect.GetJSONContains("$.tags", "element") + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } + + // For SQLite, we need a different approach since it uses LIKE + if _, ok := c.dialect.(*SQLiteDialect); ok { + ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element)) + } else { + ctx.Args = append(ctx.Args, element) + } + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error { + subconditions := []string{} + args := []any{} + + for _, v := range values { + if _, ok := c.dialect.(*SQLiteDialect); ok { + subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern")) + args = append(args, fmt.Sprintf(`%%"%s"%%`, v)) + } else { + subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element")) + args = append(args, v) + } + c.paramIndex++ + } + + if len(subconditions) == 1 { + if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { + return err + } + } else { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { + return err + } + } + + ctx.Args = append(ctx.Args, args...) + return nil +} + +func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error { + placeholders := []string{} + for range values { + placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex)) + c.paramIndex++ + } + + tablePrefix := c.dialect.GetTablePrefix() + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, values...) + return nil +} + +func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error { + if len(callExpr.Args) != 1 { + return errors.Errorf("invalid number of arguments for %s", callExpr.Function) + } + + identifier, err := GetIdentExprName(callExpr.Target) + if err != nil { + return err + } + + if identifier != "content" { + return errors.Errorf("invalid identifier for %s", callExpr.Function) + } + + arg, err := GetConstValue(callExpr.Args[0]) + if err != nil { + return err + } + + tablePrefix := c.dialect.GetTablePrefix() + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error { + identifier := identExpr.GetName() + + if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { + return errors.Errorf("invalid identifier %s", identifier) + } + + if identifier == "pinned" { + tablePrefix := c.dialect.GetTablePrefix() + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil { + return err + } + } else if identifier == "has_task_list" { + if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil { + return err + } + } + + return nil +} + +func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error { + valueInt, ok := value.(int64) + if !ok { + return errors.New("invalid integer timestamp value") + } + + timestampField := c.dialect.GetTimestampComparison(field) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, valueInt) + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", field) + } + + valueStr, ok := value.(string) + if !ok { + return errors.New("invalid string value") + } + + tablePrefix := c.dialect.GetTablePrefix() + fieldName := field + if field == "visibility" { + fieldName = "`visibility`" + } else if field == "content" { + fieldName = "`content`" + } + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, valueStr) + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", field) + } + + valueInt, ok := value.(int64) + if !ok { + return errors.New("invalid int value") + } + + tablePrefix := c.dialect.GetTablePrefix() + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil { + return err + } + + ctx.Args = append(ctx.Args, valueInt) + c.paramIndex++ + + return nil +} + +func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", field) + } + + valueBool, ok := value.(bool) + if !ok { + return errors.New("invalid boolean value for has_task_list") + } + + sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool) + if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil { + return err + } + + // For dialects that need parameters (PostgreSQL) + if _, ok := c.dialect.(*PostgreSQLDialect); ok { + ctx.Args = append(ctx.Args, valueBool) + c.paramIndex++ + } + + return nil +} + +func (c *CommonSQLConverter) getComparisonOperator(function string) string { + switch function { + case "_==_": + return "=" + case "_!=_": + return "!=" + case "_<_": + return "<" + case "_>_": + return ">" + case "_<=_": + return "<=" + case "_>=_": + return ">=" + default: + return "=" + } +} diff --git a/plugin/filter/dialect.go b/plugin/filter/dialect.go new file mode 100644 index 000000000..4f25fedbf --- /dev/null +++ b/plugin/filter/dialect.go @@ -0,0 +1,212 @@ +package filter + +import ( + "fmt" + "strings" +) + +// SQLDialect defines database-specific SQL generation methods +type SQLDialect interface { + // Basic field access + GetTablePrefix() string + GetParameterPlaceholder(index int) string + + // JSON operations + GetJSONExtract(path string) string + GetJSONArrayLength(path string) string + GetJSONContains(path, element string) string + GetJSONLike(path, pattern string) string + + // Boolean operations + GetBooleanValue(value bool) interface{} + GetBooleanComparison(path string, value bool) string + GetBooleanCheck(path string) string + + // Timestamp operations + GetTimestampComparison(field string) string + GetCurrentTimestamp() string +} + +// DatabaseType represents the type of database +type DatabaseType string + +const ( + SQLite DatabaseType = "sqlite" + MySQL DatabaseType = "mysql" + PostgreSQL DatabaseType = "postgres" +) + +// GetDialect returns the appropriate dialect for the database type +func GetDialect(dbType DatabaseType) SQLDialect { + switch dbType { + case SQLite: + return &SQLiteDialect{} + case MySQL: + return &MySQLDialect{} + case PostgreSQL: + return &PostgreSQLDialect{} + default: + return &SQLiteDialect{} // default fallback + } +} + +// SQLiteDialect implements SQLDialect for SQLite +type SQLiteDialect struct{} + +func (d *SQLiteDialect) GetTablePrefix() string { + return "`memo`" +} + +func (d *SQLiteDialect) GetParameterPlaceholder(index int) string { + return "?" +} + +func (d *SQLiteDialect) GetJSONExtract(path string) string { + return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path) +} + +func (d *SQLiteDialect) GetJSONArrayLength(path string) string { + return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path)) +} + +func (d *SQLiteDialect) GetJSONContains(path, element string) string { + return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) +} + +func (d *SQLiteDialect) GetJSONLike(path, pattern string) string { + return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) +} + +func (d *SQLiteDialect) GetBooleanValue(value bool) interface{} { + if value { + return 1 + } + return 0 +} + +func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string { + return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value)) +} + +func (d *SQLiteDialect) GetBooleanCheck(path string) string { + return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path)) +} + +func (d *SQLiteDialect) GetTimestampComparison(field string) string { + return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field) +} + +func (d *SQLiteDialect) GetCurrentTimestamp() string { + return "strftime('%s', 'now')" +} + +// MySQLDialect implements SQLDialect for MySQL +type MySQLDialect struct{} + +func (d *MySQLDialect) GetTablePrefix() string { + return "`memo`" +} + +func (d *MySQLDialect) GetParameterPlaceholder(index int) string { + return "?" +} + +func (d *MySQLDialect) GetJSONExtract(path string) string { + return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path) +} + +func (d *MySQLDialect) GetJSONArrayLength(path string) string { + return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path)) +} + +func (d *MySQLDialect) GetJSONContains(path, element string) string { + return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path)) +} + +func (d *MySQLDialect) GetJSONLike(path, pattern string) string { + return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path)) +} + +func (d *MySQLDialect) GetBooleanValue(value bool) interface{} { + return value +} + +func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string { + boolStr := "false" + if value { + boolStr = "true" + } + return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr) +} + +func (d *MySQLDialect) GetBooleanCheck(path string) string { + return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path)) +} + +func (d *MySQLDialect) GetTimestampComparison(field string) string { + return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field) +} + +func (d *MySQLDialect) GetCurrentTimestamp() string { + return "UNIX_TIMESTAMP()" +} + +// PostgreSQLDialect implements SQLDialect for PostgreSQL +type PostgreSQLDialect struct{} + +func (d *PostgreSQLDialect) GetTablePrefix() string { + return "memo" +} + +func (d *PostgreSQLDialect) GetParameterPlaceholder(index int) string { + return fmt.Sprintf("$%d", index) +} + +func (d *PostgreSQLDialect) GetJSONExtract(path string) string { + // Convert $.property.hasTaskList to payload->'property'->>'hasTaskList' + parts := strings.Split(strings.TrimPrefix(path, "$."), ".") + result := fmt.Sprintf("%s.payload", d.GetTablePrefix()) + for i, part := range parts { + if i == len(parts)-1 { + result += fmt.Sprintf("->>'%s'", part) + } else { + result += fmt.Sprintf("->'%s'", part) + } + } + return result +} + +func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string { + jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) + return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath) +} + +func (d *PostgreSQLDialect) GetJSONContains(path, element string) string { + jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) + return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath) +} + +func (d *PostgreSQLDialect) GetJSONLike(path, pattern string) string { + jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1) + return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath) +} + +func (d *PostgreSQLDialect) GetBooleanValue(value bool) interface{} { + return value +} + +func (d *PostgreSQLDialect) GetBooleanComparison(path string, value bool) string { + return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path)) +} + +func (d *PostgreSQLDialect) GetBooleanCheck(path string) string { + return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path)) +} + +func (d *PostgreSQLDialect) GetTimestampComparison(field string) string { + return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field) +} + +func (d *PostgreSQLDialect) GetCurrentTimestamp() string { + return "EXTRACT(EPOCH FROM NOW())" +} diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 576b7d8a4..6ebaca237 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -18,6 +18,7 @@ var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("updated_ts", cel.IntType), cel.Variable("pinned", cel.BoolType), cel.Variable("tag", cel.StringType), + cel.Variable("tags", cel.ListType(cel.StringType)), cel.Variable("visibility", cel.StringType), cel.Variable("has_task_list", cel.BoolType), // Current timestamp function. diff --git a/plugin/filter/templates.go b/plugin/filter/templates.go new file mode 100644 index 000000000..b6dcaeaf3 --- /dev/null +++ b/plugin/filter/templates.go @@ -0,0 +1,146 @@ +package filter + +import ( + "fmt" +) + +// SQLTemplate holds database-specific SQL fragments +type SQLTemplate struct { + SQLite string + MySQL string + PostgreSQL string +} + +// TemplateDBType represents the database type for templates +type TemplateDBType string + +const ( + SQLiteTemplate TemplateDBType = "sqlite" + MySQLTemplate TemplateDBType = "mysql" + PostgreSQLTemplate TemplateDBType = "postgres" +) + +// SQLTemplates contains common SQL patterns for different databases +var SQLTemplates = map[string]SQLTemplate{ + "json_extract": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')", + PostgreSQL: "memo.payload%s", + }, + "json_array_length": { + SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))", + MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))", + PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))", + }, + "json_contains_element": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", + MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", + PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)", + }, + "json_contains_tag": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", + MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", + PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)", + }, + "boolean_true": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true", + }, + "boolean_false": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false", + }, + "boolean_not_true": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true", + }, + "boolean_not_false": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false", + }, + "boolean_compare": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?", + }, + "boolean_check": { + SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE", + MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)", + PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE", + }, + "table_prefix": { + SQLite: "`memo`", + MySQL: "`memo`", + PostgreSQL: "memo", + }, + "timestamp_field": { + SQLite: "`memo`.`%s`", + MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)", + PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)", + }, + "content_like": { + SQLite: "`memo`.`content` LIKE ?", + MySQL: "`memo`.`content` LIKE ?", + PostgreSQL: "memo.content ILIKE ?", + }, + "visibility_in": { + SQLite: "`memo`.`visibility` IN (%s)", + MySQL: "`memo`.`visibility` IN (%s)", + PostgreSQL: "memo.visibility IN (%s)", + }, +} + +// GetSQL returns the appropriate SQL for the given template and database type +func GetSQL(templateName string, dbType TemplateDBType) string { + template, exists := SQLTemplates[templateName] + if !exists { + return "" + } + + switch dbType { + case SQLiteTemplate: + return template.SQLite + case MySQLTemplate: + return template.MySQL + case PostgreSQLTemplate: + return template.PostgreSQL + default: + return template.SQLite + } +} + +// GetParameterPlaceholder returns the appropriate parameter placeholder for the database +func GetParameterPlaceholder(dbType TemplateDBType, index int) string { + switch dbType { + case PostgreSQLTemplate: + return fmt.Sprintf("$%d", index) + default: + return "?" + } +} + +// GetParameterValue returns the appropriate parameter value for the database +func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} { + switch templateName { + case "json_contains_element", "json_contains_tag": + if dbType == SQLiteTemplate { + return fmt.Sprintf(`%%"%s"%%`, value) + } + return value + default: + return value + } +} + +// FormatPlaceholders formats a list of placeholders for the given database type +func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string { + placeholders := make([]string, count) + for i := 0; i < count; i++ { + placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i) + } + return placeholders +} diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go index 1b2e225fe..d671568f5 100644 --- a/store/db/mysql/memo_filter.go +++ b/store/db/mysql/memo_filter.go @@ -12,6 +12,12 @@ import ( ) func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + return d.convertWithTemplates(ctx, expr) +} + +func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + const dbType = filter.MySQLTemplate + if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": @@ -21,7 +27,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString("("); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { return err } operator := "AND" @@ -31,7 +37,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -44,7 +50,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -54,6 +60,39 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if len(v.CallExpr.Args) != 2 { return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + // Check if the left side is a function call like size(tags) + if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { + if leftCallExpr.CallExpr.Function == "size" { + // Handle size(tags) comparison + if len(leftCallExpr.CallExpr.Args) != 1 { + return errors.New("size function requires exactly one argument") + } + identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) + if err != nil { + return err + } + if identifier != "tags" { + return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) + } + value, err := filter.GetExprValue(v.CallExpr.Args[1]) + if err != nil { + return err + } + valueInt, ok := value.(int64) + if !ok { + return errors.New("size comparison value must be an integer") + } + operator := d.getComparisonOperator(v.CallExpr.Function) + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", + filter.GetSQL("json_array_length", dbType), operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, valueInt) + return nil + } + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { return err @@ -65,38 +104,19 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - operator := "=" - switch v.CallExpr.Function { - case "_==_": - operator = "=" - case "_!=_": - operator = "!=" - case "_<_": - operator = "<" - case "_>_": - operator = ">" - case "_<=_": - operator = "<=" - case "_>=_": - operator = ">=" - } + operator := d.getComparisonOperator(v.CallExpr.Function) if identifier == "created_ts" || identifier == "updated_ts" { - timestampInt, ok := value.(int64) + valueInt, ok := value.(int64) if !ok { - return errors.New("invalid timestamp value") + return errors.New("invalid integer timestamp value") } - var factor string - if identifier == "created_ts" { - factor = "UNIX_TIMESTAMP(`memo`.`created_ts`)" - } else if identifier == "updated_ts" { - factor = "UNIX_TIMESTAMP(`memo`.`updated_ts`)" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil { return err } - ctx.Args = append(ctx.Args, timestampInt) + ctx.Args = append(ctx.Args, valueInt) } else if identifier == "visibility" || identifier == "content" { if operator != "=" && operator != "!=" { return errors.Errorf("invalid operator for %s", v.CallExpr.Function) @@ -106,13 +126,13 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.New("invalid string value") } - var factor string + var sqlTemplate string if identifier == "visibility" { - factor = "`memo`.`visibility`" + sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`" } else if identifier == "content" { - factor = "`memo`.`content`" + sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`" } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { return err } ctx.Args = append(ctx.Args, valueStr) @@ -125,11 +145,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.New("invalid int value") } - var factor string - if identifier == "creator_id" { - factor = "`memo`.`creator_id`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`" + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { return err } ctx.Args = append(ctx.Args, valueInt) @@ -141,15 +158,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if !ok { return errors.New("invalid boolean value for has_task_list") } - - // In MySQL, we can use JSON_EXTRACT to get the value and compare it to 'true' or 'false' - compareValue := "false" - if valueBool { - compareValue = "true" + // Use template for boolean comparison + var sqlTemplate string + if operator == "=" { + if valueBool { + sqlTemplate = filter.GetSQL("boolean_true", dbType) + } else { + sqlTemplate = filter.GetSQL("boolean_false", dbType) + } + } else { // operator == "!=" + if valueBool { + sqlTemplate = filter.GetSQL("boolean_not_true", dbType) + } else { + sqlTemplate = filter.GetSQL("boolean_not_false", dbType) + } } - - // MySQL uses -> as a shorthand for JSON_EXTRACT - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST('%s' AS JSON)", operator, compareValue)); err != nil { + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { return err } } @@ -157,6 +181,29 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if len(v.CallExpr.Args) != 2 { return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + + // Check if this is "element in collection" syntax + if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { + // This is "element in collection" - the second argument is the collection + if !slices.Contains([]string{"tags"}, identifier) { + return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) + } + + if identifier == "tags" { + // Handle "element" in tags + element, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) + } + if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil { + return err + } + ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) + } + return nil + } + + // Original logic for "identifier in [list]" syntax identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { return err @@ -174,27 +221,26 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err values = append(values, value) } if identifier == "tag" { - subcodition := []string{} + subconditions := []string{} args := []any{} for _, v := range values { - subcodition, args = append(subcodition, "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)"), append(args, v) + subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType)) + args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) } - if len(subcodition) == 1 { - if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { + if len(subconditions) == 1 { + if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { return err } } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { return err } } ctx.Args = append(ctx.Args, args...) } else if identifier == "visibility" { - placeholder := []string{} - for range values { - placeholder = append(placeholder, "?") - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil { + placeholders := filter.FormatPlaceholders(dbType, len(values), 1) + visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) + if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { return err } ctx.Args = append(ctx.Args, values...) @@ -214,7 +260,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil { + if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil { return err } ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) @@ -222,17 +268,37 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { identifier := v.IdentExpr.GetName() if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { - return errors.Errorf("invalid identifier for %s", identifier) + return errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString("`memo`.`pinned` IS TRUE"); err != nil { + if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil { return err } } else if identifier == "has_task_list" { - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)"); err != nil { + // Handle has_task_list as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { return err } } } return nil } + +func (d *DB) getComparisonOperator(function string) string { + switch function { + case "_==_": + return "=" + case "_!=_": + return "!=" + case "_<_": + return "<" + case "_>_": + return ">" + case "_<=_": + return "<=" + case "_>=_": + return ">=" + default: + return "=" + } +} diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index 7a89afaa4..b5dd090d1 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -95,6 +95,26 @@ func TestConvertExprToSQL(t *testing.T) { want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?", args: []any{time.Now().Unix() - 60*60*24}, }, + { + filter: `size(tags) == 0`, + want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", + args: []any{int64(0)}, + }, + { + filter: `size(tags) > 0`, + want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?", + args: []any{int64(0)}, + }, + { + filter: `"work" in tags`, + want: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)", + args: []any{"work"}, + }, + { + filter: `size(tags) == 2`, + want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", + args: []any{int64(2)}, + }, } for _, tt := range tests { diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go index a6385a1da..61f052f49 100644 --- a/store/db/postgres/memo_filter.go +++ b/store/db/postgres/memo_filter.go @@ -12,219 +12,315 @@ import ( ) func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + const dbType = filter.PostgreSQLTemplate + _, err := d.convertWithParameterIndex(ctx, expr, dbType, len(ctx.Args)+1) + return err +} + +func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.Expr, dbType filter.TemplateDBType, paramIndex int) (int, error) { + if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } if _, err := ctx.Buffer.WriteString("("); err != nil { - return err + return paramIndex, err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { - return err + newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex) + if err != nil { + return paramIndex, err } operator := "AND" if v.CallExpr.Function == "_||_" { operator = "OR" } if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { - return err + return paramIndex, err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { - return err + newParamIndex, err = d.convertWithParameterIndex(ctx, v.CallExpr.Args[1], dbType, newParamIndex) + if err != nil { + return paramIndex, err } if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err + return paramIndex, err } + return newParamIndex, nil case "!_": if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { - return err + return paramIndex, err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { - return err + newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex) + if err != nil { + return paramIndex, err } if _, err := ctx.Buffer.WriteString(")"); err != nil { - return err + return paramIndex, err } + return newParamIndex, nil case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_": if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + // Check if the left side is a function call like size(tags) + if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { + if leftCallExpr.CallExpr.Function == "size" { + // Handle size(tags) comparison + if len(leftCallExpr.CallExpr.Args) != 1 { + return paramIndex, errors.New("size function requires exactly one argument") + } + identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) + if err != nil { + return paramIndex, err + } + if identifier != "tags" { + return paramIndex, errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) + } + value, err := filter.GetExprValue(v.CallExpr.Args[1]) + if err != nil { + return paramIndex, err + } + valueInt, ok := value.(int64) + if !ok { + return paramIndex, errors.New("size comparison value must be an integer") + } + operator := d.getComparisonOperator(v.CallExpr.Function) + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", + filter.GetSQL("json_array_length", dbType), operator, + filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { + return paramIndex, err + } + ctx.Args = append(ctx.Args, valueInt) + return paramIndex + 1, nil + } + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { - return err + return paramIndex, err } if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetExprValue(v.CallExpr.Args[1]) if err != nil { - return err - } - operator := "=" - switch v.CallExpr.Function { - case "_==_": - operator = "=" - case "_!=_": - operator = "!=" - case "_<_": - operator = "<" - case "_>_": - operator = ">" - case "_<=_": - operator = "<=" - case "_>=_": - operator = ">=" + return paramIndex, err } + operator := d.getComparisonOperator(v.CallExpr.Function) if identifier == "created_ts" || identifier == "updated_ts" { - timestampInt, ok := value.(int64) + valueInt, ok := value.(int64) if !ok { - return errors.New("invalid timestamp value") + return paramIndex, errors.New("invalid integer timestamp value") } - var factor string - if identifier == "created_ts" { - factor = "EXTRACT(EPOCH FROM memo.created_ts)" - } else if identifier == "updated_ts" { - factor = "EXTRACT(EPOCH FROM memo.updated_ts)" + timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampSQL, operator, + filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { + return paramIndex, err } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { - return err - } - ctx.Args = append(ctx.Args, timestampInt) + ctx.Args = append(ctx.Args, valueInt) + return paramIndex + 1, nil } else if identifier == "visibility" || identifier == "content" { if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) } valueStr, ok := value.(string) if !ok { - return errors.New("invalid string value") + return paramIndex, errors.New("invalid string value") } - var factor string + var sqlTemplate string if identifier == "visibility" { - factor = "memo.visibility" + sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".visibility" } else if identifier == "content" { - factor = "memo.content" + sqlTemplate = filter.GetSQL("content_like", dbType) + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { + return paramIndex, err + } + ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", valueStr)) + return paramIndex + 1, nil } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { - return err + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator, + filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { + return paramIndex, err } ctx.Args = append(ctx.Args, valueStr) + return paramIndex + 1, nil } else if identifier == "creator_id" { if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) } valueInt, ok := value.(int64) if !ok { - return errors.New("invalid int value") + return paramIndex, errors.New("invalid int value") } - factor := "memo.creator_id" - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { - return err + sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".creator_id" + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator, + filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil { + return paramIndex, err } ctx.Args = append(ctx.Args, valueInt) + return paramIndex + 1, nil } else if identifier == "has_task_list" { if operator != "=" && operator != "!=" { - return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) } valueBool, ok := value.(bool) if !ok { - return errors.New("invalid boolean value for has_task_list") + return paramIndex, errors.New("invalid boolean value for has_task_list") } - - // In PostgreSQL, extract the boolean from the JSON and compare it - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(memo.payload->'property'->>'hasTaskList')::boolean %s %s", operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { - return err + // Use parameterized template for boolean comparison (PostgreSQL only) + placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) + sqlTemplate := fmt.Sprintf(filter.GetSQL("boolean_compare", dbType), operator) + sqlTemplate = strings.Replace(sqlTemplate, "?", placeholder, 1) + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { + return paramIndex, err } ctx.Args = append(ctx.Args, valueBool) + return paramIndex + 1, nil } case "@in": if len(v.CallExpr.Args) != 2 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + + // Check if this is "element in collection" syntax + if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { + // This is "element in collection" - the second argument is the collection + if !slices.Contains([]string{"tags"}, identifier) { + return paramIndex, errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) + } + + if identifier == "tags" { + // Handle "element" in tags + element, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return paramIndex, errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) + } + placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) + sql := strings.Replace(filter.GetSQL("json_contains_element", dbType), "?", placeholder, 1) + if _, err := ctx.Buffer.WriteString(sql); err != nil { + return paramIndex, err + } + ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) + return paramIndex + 1, nil + } + return paramIndex, nil + } + + // Original logic for "identifier in [list]" syntax identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { - return err + return paramIndex, err } if !slices.Contains([]string{"tag", "visibility"}, identifier) { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } values := []any{} for _, element := range v.CallExpr.Args[1].GetListExpr().Elements { value, err := filter.GetConstValue(element) if err != nil { - return err + return paramIndex, err } values = append(values, value) } if identifier == "tag" { - subcodition := []string{} + subconditions := []string{} args := []any{} + currentParamIndex := paramIndex for _, v := range values { - subcodition, args = append(subcodition, fmt.Sprintf(`memo.payload->'tags' @> jsonb_build_array(%s)`, placeholder(len(ctx.Args)+ctx.ArgsOffset+len(args)+1))), append(args, v) + // Use parameter index for each placeholder + placeholder := filter.GetParameterPlaceholder(dbType, currentParamIndex) + subcondition := strings.Replace(filter.GetSQL("json_contains_tag", dbType), "?", placeholder, 1) + subconditions = append(subconditions, subcondition) + args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) + currentParamIndex++ } - if len(subcodition) == 1 { - if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { - return err + if len(subconditions) == 1 { + if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { + return paramIndex, err } } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { - return err + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { + return paramIndex, err } } ctx.Args = append(ctx.Args, args...) + return paramIndex + len(args), nil } else if identifier == "visibility" { - placeholders := []string{} - for i := range values { - placeholders = append(placeholders, placeholder(len(ctx.Args)+ctx.ArgsOffset+i+1)) - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("memo.visibility IN (%s)", strings.Join(placeholders, ","))); err != nil { - return err + placeholders := filter.FormatPlaceholders(dbType, len(values), paramIndex) + visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) + if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { + return paramIndex, err } ctx.Args = append(ctx.Args, values...) + return paramIndex + len(values), nil } case "contains": if len(v.CallExpr.Args) != 1 { - return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } identifier, err := filter.GetIdentExprName(v.CallExpr.Target) if err != nil { - return err + return paramIndex, err } if identifier != "content" { - return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) + return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } arg, err := filter.GetConstValue(v.CallExpr.Args[0]) if err != nil { - return err + return paramIndex, err } - if _, err := ctx.Buffer.WriteString("memo.content ILIKE " + placeholder(len(ctx.Args)+ctx.ArgsOffset+1)); err != nil { - return err + placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) + sql := strings.Replace(filter.GetSQL("content_like", dbType), "?", placeholder, 1) + if _, err := ctx.Buffer.WriteString(sql); err != nil { + return paramIndex, err } ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) + return paramIndex + 1, nil } } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { identifier := v.IdentExpr.GetName() if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { - return errors.Errorf("invalid identifier %s", identifier) + return paramIndex, errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString("memo.pinned IS TRUE"); err != nil { - return err + if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".pinned IS TRUE"); err != nil { + return paramIndex, err } } else if identifier == "has_task_list" { - if _, err := ctx.Buffer.WriteString("(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE"); err != nil { - return err + // Handle has_task_list as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { + return paramIndex, err } } } - return nil + return paramIndex, nil +} + +func (d *DB) getComparisonOperator(function string) string { + switch function { + case "_==_": + return "=" + case "_!=_": + return "!=" + case "_<_": + return "<" + case "_>_": + return ">" + case "_<=_": + return "<=" + case "_>=_": + return ">=" + default: + return "=" + } } diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index a610680f5..d38e3210c 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -95,6 +95,26 @@ func TestRestoreExprToSQL(t *testing.T) { want: "EXTRACT(EPOCH FROM memo.created_ts) > $1", args: []any{time.Now().Unix() - 60*60*24}, }, + { + filter: `size(tags) == 0`, + want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1", + args: []any{int64(0)}, + }, + { + filter: `size(tags) > 0`, + want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) > $1", + args: []any{int64(0)}, + }, + { + filter: `"work" in tags`, + want: "memo.payload->'tags' @> jsonb_build_array($1)", + args: []any{"work"}, + }, + { + filter: `size(tags) == 2`, + want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1", + args: []any{int64(2)}, + }, } for _, tt := range tests { diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index b348f726a..d055706cb 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -12,6 +12,12 @@ import ( ) func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + return d.convertWithTemplates(ctx, expr) +} + +func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error { + const dbType = filter.SQLiteTemplate + if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": @@ -21,7 +27,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString("("); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { return err } operator := "AND" @@ -31,7 +37,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -44,7 +50,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { return err } - if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -54,6 +60,39 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if len(v.CallExpr.Args) != 2 { return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + // Check if the left side is a function call like size(tags) + if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok { + if leftCallExpr.CallExpr.Function == "size" { + // Handle size(tags) comparison + if len(leftCallExpr.CallExpr.Args) != 1 { + return errors.New("size function requires exactly one argument") + } + identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0]) + if err != nil { + return err + } + if identifier != "tags" { + return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier) + } + value, err := filter.GetExprValue(v.CallExpr.Args[1]) + if err != nil { + return err + } + valueInt, ok := value.(int64) + if !ok { + return errors.New("size comparison value must be an integer") + } + operator := d.getComparisonOperator(v.CallExpr.Function) + + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", + filter.GetSQL("json_array_length", dbType), operator)); err != nil { + return err + } + ctx.Args = append(ctx.Args, valueInt) + return nil + } + } + identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { return err @@ -65,21 +104,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - operator := "=" - switch v.CallExpr.Function { - case "_==_": - operator = "=" - case "_!=_": - operator = "!=" - case "_<_": - operator = "<" - case "_>_": - operator = ">" - case "_<=_": - operator = "<=" - case "_>=_": - operator = ">=" - } + operator := d.getComparisonOperator(v.CallExpr.Function) if identifier == "created_ts" || identifier == "updated_ts" { valueInt, ok := value.(int64) @@ -87,13 +112,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.New("invalid integer timestamp value") } - var factor string - if identifier == "created_ts" { - factor = "`memo`.`created_ts`" - } else if identifier == "updated_ts" { - factor = "`memo`.`updated_ts`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier) + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil { return err } ctx.Args = append(ctx.Args, valueInt) @@ -106,13 +126,13 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.New("invalid string value") } - var factor string + var sqlTemplate string if identifier == "visibility" { - factor = "`memo`.`visibility`" + sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`" } else if identifier == "content" { - factor = "`memo`.`content`" + sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`" } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { return err } ctx.Args = append(ctx.Args, valueStr) @@ -125,11 +145,8 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.New("invalid int value") } - var factor string - if identifier == "creator_id" { - factor = "`memo`.`creator_id`" - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { + sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`" + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil { return err } ctx.Args = append(ctx.Args, valueInt) @@ -141,12 +158,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if !ok { return errors.New("invalid boolean value for has_task_list") } - // In SQLite JSON boolean values are 1 for true and 0 for false - compareValue := 0 - if valueBool { - compareValue = 1 + // Use template for boolean comparison + var sqlTemplate string + if operator == "=" { + if valueBool { + sqlTemplate = filter.GetSQL("boolean_true", dbType) + } else { + sqlTemplate = filter.GetSQL("boolean_false", dbType) + } + } else { // operator == "!=" + if valueBool { + sqlTemplate = filter.GetSQL("boolean_not_true", dbType) + } else { + sqlTemplate = filter.GetSQL("boolean_not_false", dbType) + } } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s %d", operator, compareValue)); err != nil { + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { return err } } @@ -154,6 +181,29 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if len(v.CallExpr.Args) != 2 { return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function) } + + // Check if this is "element in collection" syntax + if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil { + // This is "element in collection" - the second argument is the collection + if !slices.Contains([]string{"tags"}, identifier) { + return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier) + } + + if identifier == "tags" { + // Handle "element" in tags + element, err := filter.GetConstValue(v.CallExpr.Args[0]) + if err != nil { + return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err) + } + if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil { + return err + } + ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element)) + } + return nil + } + + // Original logic for "identifier in [list]" syntax identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0]) if err != nil { return err @@ -171,27 +221,26 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err values = append(values, value) } if identifier == "tag" { - subcodition := []string{} + subconditions := []string{} args := []any{} for _, v := range values { - subcodition, args = append(subcodition, "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?"), append(args, fmt.Sprintf(`%%"%s"%%`, v)) + subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType)) + args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v)) } - if len(subcodition) == 1 { - if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil { + if len(subconditions) == 1 { + if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil { return err } } else { - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil { return err } } ctx.Args = append(ctx.Args, args...) } else if identifier == "visibility" { - placeholder := []string{} - for range values { - placeholder = append(placeholder, "?") - } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil { + placeholders := filter.FormatPlaceholders(dbType, len(values), 1) + visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ",")) + if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil { return err } ctx.Args = append(ctx.Args, values...) @@ -211,7 +260,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil { + if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil { return err } ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg)) @@ -222,15 +271,34 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err return errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { - if _, err := ctx.Buffer.WriteString("`memo`.`pinned` IS TRUE"); err != nil { + if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil { return err } } else if identifier == "has_task_list" { // Handle has_task_list as a standalone boolean identifier - if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE"); err != nil { + if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { return err } } } return nil } + +func (d *DB) getComparisonOperator(function string) string { + switch function { + case "_==_": + return "=" + case "_!=_": + return "!=" + case "_<_": + return "<" + case "_>_": + return ">" + case "_<=_": + return "<=" + case "_>=_": + return ">=" + default: + return "=" + } +} diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index c41563412..d19c98c95 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -110,14 +110,40 @@ func TestConvertExprToSQL(t *testing.T) { want: "`memo`.`created_ts` > ?", args: []any{time.Now().Unix() - 60*60*24}, }, + { + filter: `size(tags) == 0`, + want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", + args: []any{int64(0)}, + }, + { + filter: `size(tags) > 0`, + want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) > ?", + args: []any{int64(0)}, + }, + { + filter: `"work" in tags`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?", + args: []any{`%"work"%`}, + }, + { + filter: `size(tags) == 2`, + want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", + args: []any{int64(2)}, + }, } for _, tt := range tests { db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) + if err != nil { + t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err) + } require.NoError(t, err) convertCtx := filter.NewConvertContext() err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + if err != nil { + t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err) + } require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args)