feat: support more filter factors

This commit is contained in:
johnnyjoy 2025-07-22 19:18:08 +08:00
parent 1a3fc4d874
commit b55904a428
7 changed files with 265 additions and 31 deletions

View file

@ -21,6 +21,9 @@ var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
// Current timestamp function.
cel.Function("now",
cel.Overload("now",

View file

@ -97,7 +97,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if err != nil {
return err
}
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetExprValue(v.CallExpr.Args[1])
@ -176,6 +176,44 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
return err
}
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueBool, ok := value.(bool)
if !ok {
return errors.Errorf("invalid boolean value for %s", identifier)
}
// Map identifier to JSON path
var jsonPath string
switch identifier {
case "has_link":
jsonPath = "$.property.hasLink"
case "has_code":
jsonPath = "$.property.hasCode"
case "has_incomplete_tasks":
jsonPath = "$.property.hasIncompleteTasks"
}
// Use JSON_EXTRACT for boolean comparison like has_task_list
var sqlTemplate string
if operator == "=" {
if valueBool {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('true' AS JSON)", jsonPath)
} else {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('false' AS JSON)", jsonPath)
}
} else { // operator == "!="
if valueBool {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('true' AS JSON)", jsonPath)
} else {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('false' AS JSON)", jsonPath)
}
}
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
return err
}
}
case "@in":
if len(v.CallExpr.Args) != 2 {
@ -267,7 +305,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
}
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
identifier := v.IdentExpr.GetName()
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier %s", identifier)
}
if identifier == "pinned" {
@ -279,6 +317,21 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
return err
}
} else if identifier == "has_link" {
// Handle has_link as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)"); err != nil {
return err
}
} else if identifier == "has_code" {
// Handle has_code as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)"); err != nil {
return err
}
} else if identifier == "has_incomplete_tasks" {
// Handle has_incomplete_tasks as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)"); err != nil {
return err
}
}
}
return nil

View file

@ -115,6 +115,36 @@ func TestConvertExprToSQL(t *testing.T) {
want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_code == false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('false' AS JSON)",
args: []any{},
},
{
filter: `has_incomplete_tasks != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != CAST('false' AS JSON)",
args: []any{},
},
{
filter: `has_link`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_code`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)",
args: []any{},
},
}
for _, tt := range tests {

View file

@ -103,7 +103,7 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.
if err != nil {
return paramIndex, err
}
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetExprValue(v.CallExpr.Args[1])
@ -184,6 +184,47 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.
}
ctx.Args = append(ctx.Args, valueBool)
return paramIndex + 1, nil
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
if operator != "=" && operator != "!=" {
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueBool, ok := value.(bool)
if !ok {
return paramIndex, errors.Errorf("invalid boolean value for %s", identifier)
}
// Map identifier to JSON path
var jsonPath string
switch identifier {
case "has_link":
jsonPath = "$.property.hasLink"
case "has_code":
jsonPath = "$.property.hasCode"
case "has_incomplete_tasks":
jsonPath = "$.property.hasIncompleteTasks"
}
// Use JSON path for boolean comparison with PostgreSQL parameter placeholder
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
var sqlTemplate string
if operator == "=" {
if valueBool {
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
} else {
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
}
} else { // operator == "!="
if valueBool {
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
} else {
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
}
}
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
return paramIndex, err
}
ctx.Args = append(ctx.Args, valueBool)
return paramIndex + 1, nil
}
case "@in":
if len(v.CallExpr.Args) != 2 {
@ -288,7 +329,7 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.
}
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
identifier := v.IdentExpr.GetName()
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return paramIndex, errors.Errorf("invalid identifier %s", identifier)
}
if identifier == "pinned" {
@ -300,6 +341,21 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
return paramIndex, err
}
} else if identifier == "has_link" {
// Handle has_link as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasLink')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
return paramIndex, err
}
} else if identifier == "has_code" {
// Handle has_code as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasCode')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
return paramIndex, err
}
} else if identifier == "has_incomplete_tasks" {
// Handle has_incomplete_tasks as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasIncompleteTasks')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
return paramIndex, err
}
}
}
return paramIndex, nil

View file

@ -9,7 +9,7 @@ import (
"github.com/usememos/memos/plugin/filter"
)
func TestRestoreExprToSQL(t *testing.T) {
func TestConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
@ -22,7 +22,7 @@ func TestRestoreExprToSQL(t *testing.T) {
},
{
filter: `!(tag in ["tag1", "tag2"])`,
want: `NOT ((memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2)))`,
want: "NOT ((memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2)))",
args: []any{"tag1", "tag2"},
},
{
@ -115,6 +115,36 @@ func TestRestoreExprToSQL(t *testing.T) {
want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "(memo->'payload'->'property'->>'hasLink')::boolean = $1",
args: []any{true},
},
{
filter: `has_code == false`,
want: "(memo->'payload'->'property'->>'hasCode')::boolean = $1",
args: []any{false},
},
{
filter: `has_incomplete_tasks != false`,
want: "(memo->'payload'->'property'->>'hasIncompleteTasks')::boolean != $1",
args: []any{false},
},
{
filter: `has_link`,
want: "(memo->'payload'->'property'->>'hasLink')::boolean = true",
args: []any{},
},
{
filter: `has_code`,
want: "(memo->'payload'->'property'->>'hasCode')::boolean = true",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "(memo->'payload'->'property'->>'hasIncompleteTasks')::boolean = true",
args: []any{},
},
}
for _, tt := range tests {
@ -127,4 +157,4 @@ func TestRestoreExprToSQL(t *testing.T) {
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}
}

View file

@ -97,7 +97,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if err != nil {
return err
}
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
}
value, err := filter.GetExprValue(v.CallExpr.Args[1])
@ -176,6 +176,44 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
return err
}
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
}
valueBool, ok := value.(bool)
if !ok {
return errors.Errorf("invalid boolean value for %s", identifier)
}
// Map identifier to JSON path
var jsonPath string
switch identifier {
case "has_link":
jsonPath = "$.property.hasLink"
case "has_code":
jsonPath = "$.property.hasCode"
case "has_incomplete_tasks":
jsonPath = "$.property.hasIncompleteTasks"
}
// Use JSON_EXTRACT for boolean comparison like has_task_list
var sqlTemplate string
if operator == "=" {
if valueBool {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = JSON('true')", jsonPath)
} else {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = JSON('false')", jsonPath)
}
} else { // operator == "!="
if valueBool {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != JSON('true')", jsonPath)
} else {
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != JSON('false')", jsonPath)
}
}
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
return err
}
}
case "@in":
if len(v.CallExpr.Args) != 2 {
@ -267,7 +305,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
}
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
identifier := v.IdentExpr.GetName()
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier %s", identifier)
}
if identifier == "pinned" {
@ -279,6 +317,21 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr)
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
return err
}
} else if identifier == "has_link" {
// Handle has_link as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')"); err != nil {
return err
}
} else if identifier == "has_code" {
// Handle has_code as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('true')"); err != nil {
return err
}
} else if identifier == "has_incomplete_tasks" {
// Handle has_incomplete_tasks as a standalone boolean identifier
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = JSON('true')"); err != nil {
return err
}
}
}
return nil

View file

@ -25,11 +25,6 @@ func TestConvertExprToSQL(t *testing.T) {
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`%"tag1"%`, `%"tag2"%`},
},
{
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`},
},
{
filter: `content.contains("memos")`,
want: "`memo`.`content` LIKE ?",
@ -60,16 +55,6 @@ func TestConvertExprToSQL(t *testing.T) {
want: "`memo`.`pinned` IS TRUE",
args: []any{},
},
{
filter: `!pinned`,
want: "NOT (`memo`.`pinned` IS TRUE)",
args: []any{},
},
{
filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`,
want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))",
args: []any{int64(101), "PUBLIC", "PRIVATE"},
},
{
filter: `has_task_list`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
@ -130,22 +115,46 @@ func TestConvertExprToSQL(t *testing.T) {
want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?",
args: []any{int64(2)},
},
{
filter: `has_link == true`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')",
args: []any{},
},
{
filter: `has_code == false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('false')",
args: []any{},
},
{
filter: `has_incomplete_tasks != false`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != JSON('false')",
args: []any{},
},
{
filter: `has_link`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')",
args: []any{},
},
{
filter: `has_code`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('true')",
args: []any{},
},
{
filter: `has_incomplete_tasks`,
want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = JSON('true')",
args: []any{},
},
}
for _, tt := range tests {
db := &DB{}
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
if err != nil {
t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err)
}
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err)
}
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}
}