From b55904a4284b81891c105013f09b3cb7ded61ccc Mon Sep 17 00:00:00 2001 From: johnnyjoy Date: Tue, 22 Jul 2025 19:18:08 +0800 Subject: [PATCH] feat: support more filter factors --- plugin/filter/filter.go | 3 ++ store/db/mysql/memo_filter.go | 57 ++++++++++++++++++++++++- store/db/mysql/memo_filter_test.go | 30 ++++++++++++++ store/db/postgres/memo_filter.go | 60 ++++++++++++++++++++++++++- store/db/postgres/memo_filter_test.go | 36 ++++++++++++++-- store/db/sqlite/memo_filter.go | 57 ++++++++++++++++++++++++- store/db/sqlite/memo_filter_test.go | 53 +++++++++++++---------- 7 files changed, 265 insertions(+), 31 deletions(-) diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 6ebaca237..02a4f7f29 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -21,6 +21,9 @@ var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("tags", cel.ListType(cel.StringType)), cel.Variable("visibility", cel.StringType), cel.Variable("has_task_list", cel.BoolType), + cel.Variable("has_link", cel.BoolType), + cel.Variable("has_code", cel.BoolType), + cel.Variable("has_incomplete_tasks", cel.BoolType), // Current timestamp function. cel.Function("now", cel.Overload("now", diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go index 8ae703468..b7cdee86b 100644 --- a/store/db/mysql/memo_filter.go +++ b/store/db/mysql/memo_filter.go @@ -97,7 +97,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if err != nil { return err } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetExprValue(v.CallExpr.Args[1]) @@ -176,6 +176,44 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { return err } + } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueBool, ok := value.(bool) + if !ok { + return errors.Errorf("invalid boolean value for %s", identifier) + } + + // Map identifier to JSON path + var jsonPath string + switch identifier { + case "has_link": + jsonPath = "$.property.hasLink" + case "has_code": + jsonPath = "$.property.hasCode" + case "has_incomplete_tasks": + jsonPath = "$.property.hasIncompleteTasks" + } + + // Use JSON_EXTRACT for boolean comparison like has_task_list + var sqlTemplate string + if operator == "=" { + if valueBool { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('true' AS JSON)", jsonPath) + } else { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('false' AS JSON)", jsonPath) + } + } else { // operator == "!=" + if valueBool { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('true' AS JSON)", jsonPath) + } else { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('false' AS JSON)", jsonPath) + } + } + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { + return err + } } case "@in": if len(v.CallExpr.Args) != 2 { @@ -267,7 +305,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) } } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { + if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { @@ -279,6 +317,21 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { return err } + } else if identifier == "has_link" { + // Handle has_link as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)"); err != nil { + return err + } + } else if identifier == "has_code" { + // Handle has_code as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)"); err != nil { + return err + } + } else if identifier == "has_incomplete_tasks" { + // Handle has_incomplete_tasks as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)"); err != nil { + return err + } } } return nil diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index b5dd090d1..397e2ca5a 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -115,6 +115,36 @@ func TestConvertExprToSQL(t *testing.T) { want: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", args: []any{int64(2)}, }, + { + filter: `has_link == true`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)", + args: []any{}, + }, + { + filter: `has_code == false`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('false' AS JSON)", + args: []any{}, + }, + { + filter: `has_incomplete_tasks != false`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != CAST('false' AS JSON)", + args: []any{}, + }, + { + filter: `has_link`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)", + args: []any{}, + }, + { + filter: `has_code`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)", + args: []any{}, + }, + { + filter: `has_incomplete_tasks`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)", + args: []any{}, + }, } for _, tt := range tests { diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go index 1c9802cd4..cbb3b088c 100644 --- a/store/db/postgres/memo_filter.go +++ b/store/db/postgres/memo_filter.go @@ -103,7 +103,7 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1. if err != nil { return paramIndex, err } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetExprValue(v.CallExpr.Args[1]) @@ -184,6 +184,47 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1. } ctx.Args = append(ctx.Args, valueBool) return paramIndex + 1, nil + } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { + if operator != "=" && operator != "!=" { + return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueBool, ok := value.(bool) + if !ok { + return paramIndex, errors.Errorf("invalid boolean value for %s", identifier) + } + + // Map identifier to JSON path + var jsonPath string + switch identifier { + case "has_link": + jsonPath = "$.property.hasLink" + case "has_code": + jsonPath = "$.property.hasCode" + case "has_incomplete_tasks": + jsonPath = "$.property.hasIncompleteTasks" + } + + // Use JSON path for boolean comparison with PostgreSQL parameter placeholder + placeholder := filter.GetParameterPlaceholder(dbType, paramIndex) + var sqlTemplate string + if operator == "=" { + if valueBool { + sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) + } else { + sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) + } + } else { // operator == "!=" + if valueBool { + sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) + } else { + sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder) + } + } + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { + return paramIndex, err + } + ctx.Args = append(ctx.Args, valueBool) + return paramIndex + 1, nil } case "@in": if len(v.CallExpr.Args) != 2 { @@ -288,7 +329,7 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1. } } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { + if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return paramIndex, errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { @@ -300,6 +341,21 @@ func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1. if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { return paramIndex, err } + } else if identifier == "has_link" { + // Handle has_link as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasLink')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { + return paramIndex, err + } + } else if identifier == "has_code" { + // Handle has_code as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasCode')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { + return paramIndex, err + } + } else if identifier == "has_incomplete_tasks" { + // Handle has_incomplete_tasks as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasIncompleteTasks')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil { + return paramIndex, err + } } } return paramIndex, nil diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index d38e3210c..4df217640 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -9,7 +9,7 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func TestRestoreExprToSQL(t *testing.T) { +func TestConvertExprToSQL(t *testing.T) { tests := []struct { filter string want string @@ -22,7 +22,7 @@ func TestRestoreExprToSQL(t *testing.T) { }, { filter: `!(tag in ["tag1", "tag2"])`, - want: `NOT ((memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2)))`, + want: "NOT ((memo.payload->'tags' @> jsonb_build_array($1) OR memo.payload->'tags' @> jsonb_build_array($2)))", args: []any{"tag1", "tag2"}, }, { @@ -115,6 +115,36 @@ func TestRestoreExprToSQL(t *testing.T) { want: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb)) = $1", args: []any{int64(2)}, }, + { + filter: `has_link == true`, + want: "(memo->'payload'->'property'->>'hasLink')::boolean = $1", + args: []any{true}, + }, + { + filter: `has_code == false`, + want: "(memo->'payload'->'property'->>'hasCode')::boolean = $1", + args: []any{false}, + }, + { + filter: `has_incomplete_tasks != false`, + want: "(memo->'payload'->'property'->>'hasIncompleteTasks')::boolean != $1", + args: []any{false}, + }, + { + filter: `has_link`, + want: "(memo->'payload'->'property'->>'hasLink')::boolean = true", + args: []any{}, + }, + { + filter: `has_code`, + want: "(memo->'payload'->'property'->>'hasCode')::boolean = true", + args: []any{}, + }, + { + filter: `has_incomplete_tasks`, + want: "(memo->'payload'->'property'->>'hasIncompleteTasks')::boolean = true", + args: []any{}, + }, } for _, tt := range tests { @@ -127,4 +157,4 @@ func TestRestoreExprToSQL(t *testing.T) { require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) } -} +} \ No newline at end of file diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index 9ace5c332..17f22f98b 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -97,7 +97,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if err != nil { return err } - if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } value, err := filter.GetExprValue(v.CallExpr.Args[1]) @@ -176,6 +176,44 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { return err } + } else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" { + if operator != "=" && operator != "!=" { + return errors.Errorf("invalid operator for %s", v.CallExpr.Function) + } + valueBool, ok := value.(bool) + if !ok { + return errors.Errorf("invalid boolean value for %s", identifier) + } + + // Map identifier to JSON path + var jsonPath string + switch identifier { + case "has_link": + jsonPath = "$.property.hasLink" + case "has_code": + jsonPath = "$.property.hasCode" + case "has_incomplete_tasks": + jsonPath = "$.property.hasIncompleteTasks" + } + + // Use JSON_EXTRACT for boolean comparison like has_task_list + var sqlTemplate string + if operator == "=" { + if valueBool { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = JSON('true')", jsonPath) + } else { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = JSON('false')", jsonPath) + } + } else { // operator == "!=" + if valueBool { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != JSON('true')", jsonPath) + } else { + sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != JSON('false')", jsonPath) + } + } + if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil { + return err + } } case "@in": if len(v.CallExpr.Args) != 2 { @@ -267,7 +305,7 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) } } else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok { identifier := v.IdentExpr.GetName() - if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) { + if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) { return errors.Errorf("invalid identifier %s", identifier) } if identifier == "pinned" { @@ -279,6 +317,21 @@ func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil { return err } + } else if identifier == "has_link" { + // Handle has_link as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')"); err != nil { + return err + } + } else if identifier == "has_code" { + // Handle has_code as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('true')"); err != nil { + return err + } + } else if identifier == "has_incomplete_tasks" { + // Handle has_incomplete_tasks as a standalone boolean identifier + if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = JSON('true')"); err != nil { + return err + } } } return nil diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index d19c98c95..fb7d62744 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -25,11 +25,6 @@ func TestConvertExprToSQL(t *testing.T) { want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", args: []any{`%"tag1"%`, `%"tag2"%`}, }, - { - filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`, - want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))", - args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`}, - }, { filter: `content.contains("memos")`, want: "`memo`.`content` LIKE ?", @@ -60,16 +55,6 @@ func TestConvertExprToSQL(t *testing.T) { want: "`memo`.`pinned` IS TRUE", args: []any{}, }, - { - filter: `!pinned`, - want: "NOT (`memo`.`pinned` IS TRUE)", - args: []any{}, - }, - { - filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`, - want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))", - args: []any{int64(101), "PUBLIC", "PRIVATE"}, - }, { filter: `has_task_list`, want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE", @@ -130,22 +115,46 @@ func TestConvertExprToSQL(t *testing.T) { want: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY())) = ?", args: []any{int64(2)}, }, + { + filter: `has_link == true`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')", + args: []any{}, + }, + { + filter: `has_code == false`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('false')", + args: []any{}, + }, + { + filter: `has_incomplete_tasks != false`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') != JSON('false')", + args: []any{}, + }, + { + filter: `has_link`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = JSON('true')", + args: []any{}, + }, + { + filter: `has_code`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = JSON('true')", + args: []any{}, + }, + { + filter: `has_incomplete_tasks`, + want: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = JSON('true')", + args: []any{}, + }, } for _, tt := range tests { db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) - if err != nil { - t.Logf("Failed to parse filter: %s, error: %v", tt.filter, err) - } require.NoError(t, err) convertCtx := filter.NewConvertContext() err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) - if err != nil { - t.Logf("Failed to convert filter: %s, error: %v", tt.filter, err) - } require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) } -} +} \ No newline at end of file