1
0
Fork 0

Adding upstream version 0.28.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:57:38 +02:00
parent 88f1d47ab6
commit e28c88ef14
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
933 changed files with 194711 additions and 0 deletions

718
tools/search/filter.go Normal file
View file

@ -0,0 +1,718 @@
package search
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/ganigeorgiev/fexpr"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/store"
"github.com/spf13/cast"
)
// FilterData is a filter expression string following the `fexpr` package grammar.
//
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
//
// Example:
//
//	var filter FilterData = "id = null || (name = 'test' && status = true) || (total >= {:min} && total <= {:max})"
//	resolver := search.NewSimpleFieldResolver("id", "name", "status")
//	expr, err := filter.BuildExpr(resolver, dbx.Params{"min": 100, "max": 200})
type FilterData string

// parsedFilterData holds a cache with previously parsed filter data expressions
// (initialized with some preallocated empty data map).
//
// The cache key is the raw filter string with the placeholders already
// replaced, suffixed with the expressions limit (see BuildExprWithLimit).
var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))

// BuildExpr parses the current filter data and returns a new db WHERE expression.
//
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
//
// The parsed expressions are limited up to DefaultFilterExprLimit.
// Use [FilterData.BuildExprWithLimit] if you want to set a custom limit.
func (f FilterData) BuildExpr(
	fieldResolver FieldResolver,
	placeholderReplacements ...dbx.Params,
) (dbx.Expression, error) {
	return f.BuildExprWithLimit(fieldResolver, DefaultFilterExprLimit, placeholderReplacements...)
}
// BuildExprWithLimit parses the current filter data and returns a new db WHERE expression,
// limiting the total number of parsed leaf expressions up to maxExpressions.
//
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
func (f FilterData) BuildExprWithLimit(
	fieldResolver FieldResolver,
	maxExpressions int,
	placeholderReplacements ...dbx.Params,
) (dbx.Expression, error) {
	raw := string(f)

	// replace the placeholder params in the raw string filter
	for _, p := range placeholderReplacements {
		for key, value := range p {
			var replacement string
			switch v := value.(type) {
			case nil:
				replacement = "null"
			case bool, float64, float32, int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
				// numeric and boolean values are inlined as-is, without quotes
				replacement = cast.ToString(v)
			default:
				replacement = cast.ToString(v)

				// try to json serialize as fallback
				if replacement == "" {
					raw, _ := json.Marshal(v)
					replacement = string(raw)
				}

				// quote so that the value is parsed as a single text token
				replacement = strconv.Quote(replacement)
			}
			raw = strings.ReplaceAll(raw, "{:"+key+"}", replacement)
		}
	}

	// note: the expressions limit is part of the cache key since the
	// same filter string may be built with different limits
	cacheKey := raw + "/" + strconv.Itoa(maxExpressions)

	if data, ok := parsedFilterData.GetOk(cacheKey); ok {
		return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
	}

	data, err := fexpr.Parse(raw)
	if err != nil {
		// depending on the users demand we may allow empty expressions
		// (aka. expressions consisting only of whitespaces or comments)
		// but for now disallow them as it seems unnecessary
		// if errors.Is(err, fexpr.ErrEmpty) {
		// 	return dbx.NewExp("1=1"), nil
		// }
		return nil, err
	}

	// store in cache
	// (the limit size is arbitrary and it is there to prevent the cache growing too big)
	parsedFilterData.SetIfLessThanLimit(cacheKey, data, 500)

	return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
}
// buildParsedFilterExpr converts the parsed fexpr groups into a single
// db expression, recursively descending into nested groups.
//
// maxExpressions is a shared countdown that is decremented for every
// resolved leaf expression; ErrFilterExprLimit is returned once exhausted.
func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver, maxExpressions *int) (dbx.Expression, error) {
	if len(data) == 0 {
		return nil, fexpr.ErrEmpty
	}

	result := &concatExpr{separator: " "}

	for _, group := range data {
		var expr dbx.Expression
		var exprErr error

		switch item := group.Item.(type) {
		case fexpr.Expr:
			// leaf expression - counts towards the limit
			if *maxExpressions <= 0 {
				return nil, ErrFilterExprLimit
			}
			*maxExpressions--

			expr, exprErr = resolveTokenizedExpr(item, fieldResolver)
		case fexpr.ExprGroup:
			expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver, maxExpressions)
		case []fexpr.ExprGroup:
			expr, exprErr = buildParsedFilterExpr(item, fieldResolver, maxExpressions)
		default:
			exprErr = errors.New("unsupported expression item")
		}

		if exprErr != nil {
			return nil, exprErr
		}

		// insert the group join operator between the previous and the current part
		if len(result.parts) > 0 {
			var op string
			if group.Join == fexpr.JoinOr {
				op = "OR"
			} else {
				op = "AND"
			}
			result.parts = append(result.parts, &opExpr{op})
		}

		result.parts = append(result.parts, expr)
	}

	return result, nil
}
// resolveTokenizedExpr resolves both operands of a single tokenized
// fexpr expression and combines them into a db expression.
func resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
	left, leftErr := resolveToken(expr.Left, fieldResolver)
	if leftErr != nil || left.Identifier == "" {
		return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, leftErr)
	}

	right, rightErr := resolveToken(expr.Right, fieldResolver)
	if rightErr != nil || right.Identifier == "" {
		return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rightErr)
	}

	return buildResolversExpr(left, expr.Op, right)
}
// buildResolversExpr combines the already resolved left and right operands
// with the specified fexpr operator into a single db expression,
// additionally appending the related multi-match subquery constraints
// for the non-"any" operators.
func buildResolversExpr(
	left *ResolverResult,
	op fexpr.SignOp,
	right *ResolverResult,
) (dbx.Expression, error) {
	var expr dbx.Expression

	switch op {
	case fexpr.SignEq, fexpr.SignAnyEq:
		expr = resolveEqualExpr(true, left, right)
	case fexpr.SignNeq, fexpr.SignAnyNeq:
		expr = resolveEqualExpr(false, left, right)
	case fexpr.SignLike, fexpr.SignAnyLike:
		// no bound params -> the right side is a column and therefore wrap it with "%" for contains like behavior
		if len(right.Params) == 0 {
			expr = dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
		} else {
			expr = dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
		}
	case fexpr.SignNlike, fexpr.SignAnyNlike:
		// no bound params -> the right side is a column and therefore wrap it with "%" for not-contains like behavior
		if len(right.Params) == 0 {
			expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
		} else {
			expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
		}
	case fexpr.SignLt, fexpr.SignAnyLt:
		expr = dbx.NewExp(fmt.Sprintf("%s < %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
	case fexpr.SignLte, fexpr.SignAnyLte:
		expr = dbx.NewExp(fmt.Sprintf("%s <= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
	case fexpr.SignGt, fexpr.SignAnyGt:
		expr = dbx.NewExp(fmt.Sprintf("%s > %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
	case fexpr.SignGte, fexpr.SignAnyGte:
		expr = dbx.NewExp(fmt.Sprintf("%s >= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
	}

	if expr == nil {
		return nil, fmt.Errorf("unknown expression operator %q", op)
	}

	// multi-match expressions
	// (the extra subquery constraint is added only for the regular,
	// non-"any" operator variants)
	if !isAnyMatchOp(op) {
		if left.MultiMatchSubQuery != nil && right.MultiMatchSubQuery != nil {
			mm := &manyVsManyExpr{
				left:  left,
				right: right,
				op:    op,
			}
			expr = dbx.Enclose(dbx.And(expr, mm))
		} else if left.MultiMatchSubQuery != nil {
			mm := &manyVsOneExpr{
				noCoalesce:   left.NoCoalesce,
				subQuery:     left.MultiMatchSubQuery,
				op:           op,
				otherOperand: right,
			}
			expr = dbx.Enclose(dbx.And(expr, mm))
		} else if right.MultiMatchSubQuery != nil {
			mm := &manyVsOneExpr{
				noCoalesce:   right.NoCoalesce,
				subQuery:     right.MultiMatchSubQuery,
				op:           op,
				otherOperand: left,
				inverse:      true,
			}
			expr = dbx.Enclose(dbx.And(expr, mm))
		}
	}

	// apply the operands post-build hooks (if any)
	if left.AfterBuild != nil {
		expr = left.AfterBuild(expr)
	}
	if right.AfterBuild != nil {
		expr = right.AfterBuild(expr)
	}

	return expr, nil
}
// normalizedIdentifiers holds fallback SQL tokens for special literal
// identifiers that have no matching resolver field (matched case-insensitively).
var normalizedIdentifiers = map[string]string{
	// if `null` field is missing, treat `null` identifier as NULL token
	"null": "NULL",
	// if `true` field is missing, treat `true` identifier as TRUE token
	"true": "1",
	// if `false` field is missing, treat `false` identifier as FALSE token
	"false": "0",
}
// resolveToken converts a single fexpr token into a ResolverResult,
// handling identifier macros, custom field identifiers, text/number
// literals and function calls.
//
// Text and number literals are returned as dbx placeholders with a
// pseudorandom parameter name to avoid collisions.
func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) {
	switch token.Type {
	case fexpr.TokenIdentifier:
		// check for macros
		// ---
		if macroFunc, ok := identifierMacros[token.Literal]; ok {
			placeholder := "t" + security.PseudorandomString(8)

			macroValue, err := macroFunc()
			if err != nil {
				return nil, err
			}

			return &ResolverResult{
				Identifier: "{:" + placeholder + "}",
				Params:     dbx.Params{placeholder: macroValue},
			}, nil
		}

		// custom resolver
		// ---
		result, err := fieldResolver.Resolve(token.Literal)
		if err != nil || result.Identifier == "" {
			// fallback to the raw SQL tokens for the special
			// null/true/false literals (case-insensitive)
			for k, v := range normalizedIdentifiers {
				if strings.EqualFold(k, token.Literal) {
					return &ResolverResult{Identifier: v}, nil
				}
			}
			return nil, err
		}

		return result, err
	case fexpr.TokenText:
		placeholder := "t" + security.PseudorandomString(8)

		return &ResolverResult{
			Identifier: "{:" + placeholder + "}",
			Params:     dbx.Params{placeholder: token.Literal},
		}, nil
	case fexpr.TokenNumber:
		placeholder := "t" + security.PseudorandomString(8)

		return &ResolverResult{
			Identifier: "{:" + placeholder + "}",
			Params:     dbx.Params{placeholder: cast.ToFloat64(token.Literal)},
		}, nil
	case fexpr.TokenFunction:
		fn, ok := TokenFunctions[token.Literal]
		if !ok {
			return nil, fmt.Errorf("unknown function %q", token.Literal)
		}

		// function arguments are resolved lazily via the provided callback
		args, _ := token.Meta.([]fexpr.Token)
		return fn(func(argToken fexpr.Token) (*ResolverResult, error) {
			return resolveToken(argToken, fieldResolver)
		}, args...)
	}

	return nil, fmt.Errorf("unsupported token type %q", token.Type)
}
// resolveEqualExpr resolves = and != expressions in an attempt to minimize the
// COALESCE usage and to gracefully handle null vs empty string normalizations.
//
// equal toggles between the "=" (true) and "!=" (false) semantics.
//
// The expression `a = "" OR a is null` tends to perform better than
// `COALESCE(a, "") = ""` since the direct match can be accomplished
// with a seek while the COALESCE will induce a table scan.
func resolveEqualExpr(equal bool, left, right *ResolverResult) dbx.Expression {
	// an operand counts as "empty" when it is an empty/null identifier
	// or when its single bound param is nil or ""
	isLeftEmpty := isEmptyIdentifier(left) || (len(left.Params) == 1 && hasEmptyParamValue(left))
	isRightEmpty := isEmptyIdentifier(right) || (len(right.Params) == 1 && hasEmptyParamValue(right))

	equalOp := "="
	nullEqualOp := "IS"
	concatOp := "OR"
	nullExpr := "IS NULL"
	if !equal {
		// always use `IS NOT` instead of `!=` because direct non-equal comparisons
		// to nullable column values that are actually NULL yields to NULL instead of TRUE, eg.:
		// `'example' != nullableColumn` -> NULL even if nullableColumn row value is NULL
		equalOp = "IS NOT"
		nullEqualOp = equalOp
		concatOp = "AND"
		nullExpr = "IS NOT NULL"
	}

	// no coalesce (eg. compare to a json field)
	// a IS b
	// a IS NOT b
	if left.NoCoalesce || right.NoCoalesce {
		return dbx.NewExp(
			fmt.Sprintf("%s %s %s", left.Identifier, nullEqualOp, right.Identifier),
			mergeParams(left.Params, right.Params),
		)
	}

	// both operands are empty
	if isLeftEmpty && isRightEmpty {
		return dbx.NewExp(fmt.Sprintf("'' %s ''", equalOp), mergeParams(left.Params, right.Params))
	}

	// direct compare since at least one of the operands is known to be non-empty
	// eg. a = 'example'
	if isKnownNonEmptyIdentifier(left) || isKnownNonEmptyIdentifier(right) {
		leftIdentifier := left.Identifier
		if isLeftEmpty {
			leftIdentifier = "''"
		}
		rightIdentifier := right.Identifier
		if isRightEmpty {
			rightIdentifier = "''"
		}
		return dbx.NewExp(
			fmt.Sprintf("%s %s %s", leftIdentifier, equalOp, rightIdentifier),
			mergeParams(left.Params, right.Params),
		)
	}

	// "" = b OR b IS NULL
	// "" IS NOT b AND b IS NOT NULL
	if isLeftEmpty {
		return dbx.NewExp(
			fmt.Sprintf("('' %s %s %s %s %s)", equalOp, right.Identifier, concatOp, right.Identifier, nullExpr),
			mergeParams(left.Params, right.Params),
		)
	}

	// a = "" OR a IS NULL
	// a IS NOT "" AND a IS NOT NULL
	if isRightEmpty {
		return dbx.NewExp(
			fmt.Sprintf("(%s %s '' %s %s %s)", left.Identifier, equalOp, concatOp, left.Identifier, nullExpr),
			mergeParams(left.Params, right.Params),
		)
	}

	// fallback to a COALESCE comparison
	return dbx.NewExp(
		fmt.Sprintf(
			"COALESCE(%s, '') %s COALESCE(%s, '')",
			left.Identifier,
			equalOp,
			right.Identifier,
		),
		mergeParams(left.Params, right.Params),
	)
}
// hasEmptyParamValue reports whether at least one of the resolver result
// params is nil or an empty string.
func hasEmptyParamValue(result *ResolverResult) bool {
	for _, param := range result.Params {
		if param == nil {
			return true
		}
		if str, ok := param.(string); ok && str == "" {
			return true
		}
	}
	return false
}
// isKnownNonEmptyIdentifier reports whether the resolver result is
// guaranteed to produce a non-empty value (a boolean-like SQL literal
// or a non-empty bound parameter).
func isKnownNonEmptyIdentifier(result *ResolverResult) bool {
	lower := strings.ToLower(result.Identifier)
	if lower == "1" || lower == "0" || lower == "false" || lower == "true" {
		return true
	}

	return len(result.Params) > 0 && !isEmptyIdentifier(result) && !hasEmptyParamValue(result)
}
// isEmptyIdentifier reports whether the resolver result identifier
// represents an empty or null SQL value.
func isEmptyIdentifier(result *ResolverResult) bool {
	id := strings.ToLower(result.Identifier)

	return id == "" || id == "null" || id == "''" || id == `""` || id == "``"
}
// isAnyMatchOp reports whether the operator is one of the "any" (?) variants.
func isAnyMatchOp(op fexpr.SignOp) bool {
	switch op {
	case fexpr.SignAnyEq, fexpr.SignAnyNeq,
		fexpr.SignAnyLike, fexpr.SignAnyNlike,
		fexpr.SignAnyLt, fexpr.SignAnyLte,
		fexpr.SignAnyGt, fexpr.SignAnyGte:
		return true
	default:
		return false
	}
}
// mergeParams returns new dbx.Params where each provided params item
// is merged in the order they are specified
// (later entries overwrite earlier ones on key conflict).
func mergeParams(params ...dbx.Params) dbx.Params {
	merged := dbx.Params{}

	for _, group := range params {
		for name, value := range group {
			merged[name] = value
		}
	}

	return merged
}
// @todo consider adding support for custom single character wildcard
//
// wrapLikeParams wraps each provided param value string with `%`
// if the param doesn't contain an explicit wildcard (`%`) character already.
func wrapLikeParams(params dbx.Params) dbx.Params {
	wrapped := dbx.Params{}

	for name, value := range params {
		str := cast.ToString(value)

		if !containsUnescapedChar(str, '%') {
			// note: this is done to minimize the breaking changes and to preserve the original autoescape behavior
			str = "%" + escapeUnescapedChars(str, '\\', '%', '_') + "%"
		}

		wrapped[name] = str
	}

	return wrapped
}
// escapeUnescapedChars inserts a `\` before the unescaped occurrences of
// the specified escapeChars in str.
//
// The string is scanned from right to left: when an unescaped escape char
// is matched, a `\` is inserted before it unless the preceding character
// is already a backslash.
//
// NOTE(review): the preceding character of a match is consumed without
// being checked against escapeChars itself, so in a run of consecutive
// escape chars some are left unescaped (eg. "%%" -> "%\%") - presumably
// kept intentionally for backward compatibility with the original
// autoescape behavior; confirm before changing.
func escapeUnescapedChars(str string, escapeChars ...rune) string {
	rs := []rune(str)
	total := len(rs)
	result := make([]rune, 0, total)

	var match bool

	for i := total - 1; i >= 0; i-- {
		if match {
			// check if already escaped
			if rs[i] != '\\' {
				result = append(result, '\\')
			}
			match = false
		} else {
			for _, ec := range escapeChars {
				if rs[i] == ec {
					match = true
					break
				}
			}
		}

		result = append(result, rs[i])

		// in case the matching char is at the beginning
		if i == 0 && match {
			result = append(result, '\\')
		}
	}

	// reverse (the result was built backwards)
	for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
		result[i], result[j] = result[j], result[i]
	}

	return string(result)
}
func containsUnescapedChar(str string, ch rune) bool {
var prev rune
for _, c := range str {
if c == ch && prev != '\\' {
return true
}
if c == '\\' && prev == '\\' {
prev = rune(0) // reset escape sequence
} else {
prev = c
}
}
return false
}
// -------------------------------------------------------------------

var _ dbx.Expression = (*opExpr)(nil)

// opExpr defines an expression that contains a raw sql operator string.
type opExpr struct {
	op string
}

// Build converts the expression into a SQL fragment
// (the raw operator string as-is).
//
// Implements [dbx.Expression] interface.
func (e *opExpr) Build(db *dbx.DB, params dbx.Params) string {
	return e.op
}
// -------------------------------------------------------------------

var _ dbx.Expression = (*concatExpr)(nil)

// concatExpr defines an expression that concatenates multiple
// other expressions with a specified separator.
type concatExpr struct {
	separator string
	// parts may contain nil entries - they are skipped on Build
	parts []dbx.Expression
}
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
	if len(e.parts) == 0 {
		return ""
	}

	fragments := make([]string, 0, len(e.parts))
	for _, part := range e.parts {
		if part == nil {
			continue
		}
		if fragment := part.Build(db, params); fragment != "" {
			fragments = append(fragments, fragment)
		}
	}

	// skip extra parenthesis for a single concat expression,
	// unless it is an already concatenated raw/plain AND/OR expression
	if len(fragments) == 1 {
		upper := strings.ToUpper(fragments[0])
		if !strings.Contains(upper, " AND ") && !strings.Contains(upper, " OR ") {
			return fragments[0]
		}
	}

	return "(" + strings.Join(fragments, e.separator) + ")"
}
// -------------------------------------------------------------------

var _ dbx.Expression = (*manyVsManyExpr)(nil)

// manyVsManyExpr constructs a multi-match many<->many db where expression.
//
// Expects both operands' MultiMatchSubQuery to return a subquery with a
// single "multiMatchValue" column.
type manyVsManyExpr struct {
	left  *ResolverResult
	right *ResolverResult
	op    fexpr.SignOp
}
// Build converts the expression into a SQL fragment
// (a NOT EXISTS guard over the left-joined multi-match subqueries).
//
// Implements [dbx.Expression] interface.
func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string {
	if e.left.MultiMatchSubQuery == nil || e.right.MultiMatchSubQuery == nil {
		return "0=1"
	}

	// pseudorandom aliases for the joined subqueries
	lAlias := "__ml" + security.PseudorandomString(8)
	rAlias := "__mr" + security.PseudorandomString(8)

	whereExpr, buildErr := buildResolversExpr(
		&ResolverResult{
			NoCoalesce: e.left.NoCoalesce,
			Identifier: "[[" + lAlias + ".multiMatchValue]]",
		},
		e.op,
		&ResolverResult{
			NoCoalesce: e.right.NoCoalesce,
			Identifier: "[[" + rAlias + ".multiMatchValue]]",
			// note: the AfterBuild needs to be handled only once and it
			// doesn't matter whether it is applied on the left or right subquery operand
			AfterBuild: dbx.Not, // inverse for the not-exist expression
		},
	)
	if buildErr != nil {
		return "0=1"
	}

	return fmt.Sprintf(
		"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} LEFT JOIN (%s) {{%s}} WHERE %s)",
		e.left.MultiMatchSubQuery.Build(db, params),
		lAlias,
		e.right.MultiMatchSubQuery.Build(db, params),
		rAlias,
		whereExpr.Build(db, params),
	)
}
// -------------------------------------------------------------------

var _ dbx.Expression = (*manyVsOneExpr)(nil)

// manyVsOneExpr constructs a multi-match many<->one db where expression.
//
// Expects subQuery to return a subquery with a single "multiMatchValue" column.
//
// You can set inverse=true to reverse the condition sides (aka. one<->many).
type manyVsOneExpr struct {
	otherOperand *ResolverResult
	subQuery     dbx.Expression
	op           fexpr.SignOp
	inverse      bool
	noCoalesce   bool
}
// Build converts the expression into a SQL fragment
// (a NOT EXISTS guard over the multi-match subquery).
//
// Implements [dbx.Expression] interface.
func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string {
	if e.subQuery == nil {
		return "0=1"
	}

	// pseudorandom alias for the joined subquery
	alias := "__sm" + security.PseudorandomString(8)

	r1 := &ResolverResult{
		NoCoalesce: e.noCoalesce,
		Identifier: "[[" + alias + ".multiMatchValue]]",
		AfterBuild: dbx.Not, // inverse for the not-exist expression
	}

	r2 := &ResolverResult{
		Identifier: e.otherOperand.Identifier,
		Params:     e.otherOperand.Params,
	}

	var whereExpr dbx.Expression
	var buildErr error

	// inverse swaps the operand sides (one<->many)
	if e.inverse {
		whereExpr, buildErr = buildResolversExpr(r2, e.op, r1)
	} else {
		whereExpr, buildErr = buildResolversExpr(r1, e.op, r2)
	}
	if buildErr != nil {
		return "0=1"
	}

	return fmt.Sprintf(
		"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} WHERE %s)",
		e.subQuery.Build(db, params),
		alias,
		whereExpr.Build(db, params),
	)
}

341
tools/search/filter_test.go Normal file
View file

@ -0,0 +1,341 @@
package search_test
import (
"context"
"database/sql"
"fmt"
"regexp"
"strings"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/search"
)
// TestFilterDataBuildExpr verifies the generated WHERE expression for
// various filter strings by matching the raw SQL against a regexp pattern
// (the bound parameter names are pseudorandom, hence the TEST wildcards).
func TestFilterDataBuildExpr(t *testing.T) {
	resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", `^test4_\w+$`, `^test5\.[\w\.\:]*\w+$`)

	scenarios := []struct {
		name          string
		filterData    search.FilterData
		expectError   bool
		expectPattern string
	}{
		{
			"empty",
			"",
			true,
			"",
		},
		{
			"invalid format",
			"(test1 > 1",
			true,
			"",
		},
		{
			"invalid operator",
			"test1 + 123",
			true,
			"",
		},
		{
			"unknown field",
			"test1 = 'example' && unknown > 1",
			true,
			"",
		},
		{
			"simple expression",
			"test1 > 1",
			false,
			"[[test1]] > {:TEST}",
		},
		{
			"empty string vs null",
			"'' = null && null != ''",
			false,
			"('' = '' AND '' IS NOT '')",
		},
		{
			"like with 2 columns",
			"test1 ~ test2",
			false,
			"[[test1]] LIKE ('%' || [[test2]] || '%') ESCAPE '\\'",
		},
		{
			"like with right column operand",
			"'lorem' ~ test1",
			false,
			"{:TEST} LIKE ('%' || [[test1]] || '%') ESCAPE '\\'",
		},
		{
			"like with left column operand and text as right operand",
			"test1 ~ 'lorem'",
			false,
			"[[test1]] LIKE {:TEST} ESCAPE '\\'",
		},
		{
			"not like with 2 columns",
			"test1 !~ test2",
			false,
			"[[test1]] NOT LIKE ('%' || [[test2]] || '%') ESCAPE '\\'",
		},
		{
			"not like with right column operand",
			"'lorem' !~ test1",
			false,
			"{:TEST} NOT LIKE ('%' || [[test1]] || '%') ESCAPE '\\'",
		},
		{
			"like with left column operand and text as right operand",
			"test1 !~ 'lorem'",
			false,
			"[[test1]] NOT LIKE {:TEST} ESCAPE '\\'",
		},
		{
			"nested json no coalesce",
			"test5.a = test5.b || test5.c != test5.d",
			false,
			"(JSON_EXTRACT([[test5]], '$.a') IS JSON_EXTRACT([[test5]], '$.b') OR JSON_EXTRACT([[test5]], '$.c') IS NOT JSON_EXTRACT([[test5]], '$.d'))",
		},
		{
			"macros",
			`
				test4_1 > @now &&
				test4_2 > @second &&
				test4_3 > @minute &&
				test4_4 > @hour &&
				test4_5 > @day &&
				test4_6 > @year &&
				test4_7 > @month &&
				test4_9 > @weekday &&
				test4_9 > @todayStart &&
				test4_10 > @todayEnd &&
				test4_11 > @monthStart &&
				test4_12 > @monthEnd &&
				test4_13 > @yearStart &&
				test4_14 > @yearEnd
			`,
			false,
			"([[test4_1]] > {:TEST} AND [[test4_2]] > {:TEST} AND [[test4_3]] > {:TEST} AND [[test4_4]] > {:TEST} AND [[test4_5]] > {:TEST} AND [[test4_6]] > {:TEST} AND [[test4_7]] > {:TEST} AND [[test4_9]] > {:TEST} AND [[test4_9]] > {:TEST} AND [[test4_10]] > {:TEST} AND [[test4_11]] > {:TEST} AND [[test4_12]] > {:TEST} AND [[test4_13]] > {:TEST} AND [[test4_14]] > {:TEST})",
		},
		{
			"complex expression",
			"((test1 > 1) || (test2 != 2)) && test3 ~ '%%example' && test4_sub = null",
			false,
			"(([[test1]] > {:TEST} OR [[test2]] IS NOT {:TEST}) AND [[test3]] LIKE {:TEST} ESCAPE '\\' AND ([[test4_sub]] = '' OR [[test4_sub]] IS NULL))",
		},
		{
			"combination of special literals (null, true, false)",
			"test1=true && test2 != false && null = test3 || null != test4_sub",
			false,
			"([[test1]] = 1 AND [[test2]] IS NOT 0 AND ('' = [[test3]] OR [[test3]] IS NULL) OR ('' IS NOT [[test4_sub]] AND [[test4_sub]] IS NOT NULL))",
		},
		{
			"all operators",
			"(test1 = test2 || test2 != test3) && (test2 ~ 'example' || test2 !~ '%%abc') && 'switch1%%' ~ test1 && 'switch2' !~ test2 && test3 > 1 && test3 >= 0 && test3 <= 4 && 2 < 5",
			false,
			"((COALESCE([[test1]], '') = COALESCE([[test2]], '') OR COALESCE([[test2]], '') IS NOT COALESCE([[test3]], '')) AND ([[test2]] LIKE {:TEST} ESCAPE '\\' OR [[test2]] NOT LIKE {:TEST} ESCAPE '\\') AND {:TEST} LIKE ('%' || [[test1]] || '%') ESCAPE '\\' AND {:TEST} NOT LIKE ('%' || [[test2]] || '%') ESCAPE '\\' AND [[test3]] > {:TEST} AND [[test3]] >= {:TEST} AND [[test3]] <= {:TEST} AND {:TEST} < {:TEST})",
		},
		{
			"geoDistance function",
			"geoDistance(1,2,3,4) < 567",
			false,
			"(6371 * acos(cos(radians({:TEST})) * cos(radians({:TEST})) * cos(radians({:TEST}) - radians({:TEST})) + sin(radians({:TEST})) * sin(radians({:TEST})))) < {:TEST}",
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			expr, err := s.filterData.BuildExpr(resolver)

			hasErr := err != nil
			if hasErr != s.expectError {
				t.Fatalf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
			}

			if hasErr {
				return
			}

			dummyDB := &dbx.DB{}

			rawSql := expr.Build(dummyDB, dbx.Params{})

			// replace the TEST placeholder with a \w+ regex pattern
			expectPattern := strings.ReplaceAll(
				"^"+regexp.QuoteMeta(s.expectPattern)+"$",
				"TEST",
				`\w+`,
			)

			pattern := regexp.MustCompile(expectPattern)
			if !pattern.MatchString(rawSql) {
				t.Fatalf("[%s] Pattern %v don't match with expression: \n%v", s.name, expectPattern, rawSql)
			}
		})
	}
}
// TestFilterDataBuildExprWithParams verifies the placeholder replacement
// and quoting of the different placeholderReplacements value types
// (booleans, numbers, nil, strings, time.Time, slices and maps)
// by inspecting the final SQL sent to the db.
func TestFilterDataBuildExprWithParams(t *testing.T) {
	// create a dummy db
	sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
	if err != nil {
		t.Fatal(err)
	}
	db := dbx.NewFromDB(sqlDB, "sqlite")

	// capture the executed queries
	calledQueries := []string{}
	db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
		calledQueries = append(calledQueries, sql)
	}
	db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
		calledQueries = append(calledQueries, sql)
	}

	date, err := time.Parse("2006-01-02", "2023-01-01")
	if err != nil {
		t.Fatal(err)
	}

	resolver := search.NewSimpleFieldResolver(`^test\w+$`)

	filter := search.FilterData(`
		test1 = {:test1} ||
		test2 = {:test2} ||
		test3a = {:test3} ||
		test3b = {:test3} ||
		test4 = {:test4} ||
		test5 = {:test5} ||
		test6 = {:test6} ||
		test7 = {:test7} ||
		test8 = {:test8} ||
		test9 = {:test9} ||
		test10 = {:test10} ||
		test11 = {:test11} ||
		test12 = {:test12}
	`)

	replacements := []dbx.Params{
		{"test1": true},
		{"test2": false},
		{"test3": 123.456},
		{"test4": nil},
		{"test5": "", "test6": "simple", "test7": `'single_quotes'`, "test8": `"double_quotes"`, "test9": `escape\"quote`},
		{"test10": date},
		{"test11": []string{"a", "b", `"quote`}},
		{"test12": map[string]any{"a": 123, "b": `quote"`}},
	}

	expr, err := filter.BuildExpr(resolver, replacements...)
	if err != nil {
		t.Fatal(err)
	}

	db.Select().Where(expr).Build().Execute()

	if len(calledQueries) != 1 {
		t.Fatalf("Expected 1 query, got %d", len(calledQueries))
	}

	expectedQuery := `SELECT * WHERE ([[test1]] = 1 OR [[test2]] = 0 OR [[test3a]] = 123.456 OR [[test3b]] = 123.456 OR ([[test4]] = '' OR [[test4]] IS NULL) OR [[test5]] = '""' OR [[test6]] = 'simple' OR [[test7]] = '''single_quotes''' OR [[test8]] = '"double_quotes"' OR [[test9]] = 'escape\\"quote' OR [[test10]] = '2023-01-01 00:00:00 +0000 UTC' OR [[test11]] = '["a","b","\\"quote"]' OR [[test12]] = '{"a":123,"b":"quote\\""}')`
	if expectedQuery != calledQueries[0] {
		t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
	}
}
// TestFilterDataBuildExprWithLimit verifies that the expressions limit
// is enforced (including for nested groups and cached filter strings).
func TestFilterDataBuildExprWithLimit(t *testing.T) {
	resolver := search.NewSimpleFieldResolver(`^\w+$`)

	scenarios := []struct {
		limit       int
		filter      search.FilterData
		expectError bool
	}{
		{1, "1 = 1", false},
		{0, "1 = 1", true}, // new cache entry should be created
		{2, "1 = 1 || 1 = 1", false},
		{1, "1 = 1 || 1 = 1", true},
		{3, "1 = 1 || 1 = 1", false},
		{6, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", false},
		{5, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", true},
	}

	for idx, scenario := range scenarios {
		t.Run(fmt.Sprintf("limit_%d:%d", idx, scenario.limit), func(t *testing.T) {
			_, err := scenario.filter.BuildExprWithLimit(resolver, scenario.limit)

			gotErr := err != nil
			if gotErr != scenario.expectError {
				t.Fatalf("Expected hasErr %v, got %v", scenario.expectError, gotErr)
			}
		})
	}
}
// TestLikeParamsWrapping verifies the autowrapping with "%" and the
// autoescaping of special LIKE characters for ~ filter params
// (params with an explicit unescaped "%" are left untouched).
func TestLikeParamsWrapping(t *testing.T) {
	// create a dummy db
	sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
	if err != nil {
		t.Fatal(err)
	}
	db := dbx.NewFromDB(sqlDB, "sqlite")

	// capture the executed queries
	calledQueries := []string{}
	db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
		calledQueries = append(calledQueries, sql)
	}
	db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
		calledQueries = append(calledQueries, sql)
	}

	resolver := search.NewSimpleFieldResolver(`^test\w+$`)

	filter := search.FilterData(`
		test1 ~ {:p1} ||
		test2 ~ {:p2} ||
		test3 ~ {:p3} ||
		test4 ~ {:p4} ||
		test5 ~ {:p5} ||
		test6 ~ {:p6} ||
		test7 ~ {:p7} ||
		test8 ~ {:p8} ||
		test9 ~ {:p9} ||
		test10 ~ {:p10} ||
		test11 ~ {:p11} ||
		test12 ~ {:p12}
	`)

	replacements := []dbx.Params{
		{"p1": `abc`},
		{"p2": `ab%c`},
		{"p3": `ab\%c`},
		{"p4": `%ab\%c`},
		{"p5": `ab\\%c`},
		{"p6": `ab\\\%c`},
		{"p7": `ab_c`},
		{"p8": `ab\_c`},
		{"p9": `%ab_c`},
		{"p10": `ab\c`},
		{"p11": `_ab\c_`},
		{"p12": `ab\c%`},
	}

	expr, err := filter.BuildExpr(resolver, replacements...)
	if err != nil {
		t.Fatal(err)
	}

	db.Select().Where(expr).Build().Execute()

	if len(calledQueries) != 1 {
		t.Fatalf("Expected 1 query, got %d", len(calledQueries))
	}

	expectedQuery := `SELECT * WHERE ([[test1]] LIKE '%abc%' ESCAPE '\' OR [[test2]] LIKE 'ab%c' ESCAPE '\' OR [[test3]] LIKE 'ab\\%c' ESCAPE '\' OR [[test4]] LIKE '%ab\\%c' ESCAPE '\' OR [[test5]] LIKE 'ab\\\\%c' ESCAPE '\' OR [[test6]] LIKE 'ab\\\\\\%c' ESCAPE '\' OR [[test7]] LIKE '%ab\_c%' ESCAPE '\' OR [[test8]] LIKE '%ab\\\_c%' ESCAPE '\' OR [[test9]] LIKE '%ab_c' ESCAPE '\' OR [[test10]] LIKE '%ab\\c%' ESCAPE '\' OR [[test11]] LIKE '%\_ab\\c\_%' ESCAPE '\' OR [[test12]] LIKE 'ab\\c%' ESCAPE '\')`
	if expectedQuery != calledQueries[0] {
		t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
	}
}

View file

@ -0,0 +1,135 @@
package search
import (
"fmt"
"time"
"github.com/pocketbase/pocketbase/tools/types"
)
// timeNow returns the current time and is declared as a variable
// so that it can be replaced with a fixed clock.
//
// note: used primarily for the tests
var timeNow = func() time.Time {
	return time.Now()
}
// identifierMacros maps the supported @-prefixed filter identifiers to
// the functions producing their current values (all based on the UTC clock).
//
// The date-like macros return their value as a formatted datetime string,
// while the calendar components (@second, @minute, etc.) return plain ints.
var identifierMacros = map[string]func() (any, error){
	"@now": func() (any, error) {
		today := timeNow().UTC()

		d, err := types.ParseDateTime(today)
		if err != nil {
			return "", fmt.Errorf("@now: %w", err)
		}

		return d.String(), nil
	},
	"@yesterday": func() (any, error) {
		yesterday := timeNow().UTC().AddDate(0, 0, -1)

		d, err := types.ParseDateTime(yesterday)
		if err != nil {
			return "", fmt.Errorf("@yesterday: %w", err)
		}

		return d.String(), nil
	},
	"@tomorrow": func() (any, error) {
		tomorrow := timeNow().UTC().AddDate(0, 0, 1)

		d, err := types.ParseDateTime(tomorrow)
		if err != nil {
			return "", fmt.Errorf("@tomorrow: %w", err)
		}

		return d.String(), nil
	},
	"@second": func() (any, error) {
		return timeNow().UTC().Second(), nil
	},
	"@minute": func() (any, error) {
		return timeNow().UTC().Minute(), nil
	},
	"@hour": func() (any, error) {
		return timeNow().UTC().Hour(), nil
	},
	"@day": func() (any, error) {
		return timeNow().UTC().Day(), nil
	},
	"@month": func() (any, error) {
		return int(timeNow().UTC().Month()), nil
	},
	"@weekday": func() (any, error) {
		return int(timeNow().UTC().Weekday()), nil
	},
	"@year": func() (any, error) {
		return timeNow().UTC().Year(), nil
	},
	"@todayStart": func() (any, error) {
		today := timeNow().UTC()
		start := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC)

		d, err := types.ParseDateTime(start)
		if err != nil {
			return "", fmt.Errorf("@todayStart: %w", err)
		}

		return d.String(), nil
	},
	"@todayEnd": func() (any, error) {
		today := timeNow().UTC()
		start := time.Date(today.Year(), today.Month(), today.Day(), 23, 59, 59, 999999999, time.UTC)

		d, err := types.ParseDateTime(start)
		if err != nil {
			return "", fmt.Errorf("@todayEnd: %w", err)
		}

		return d.String(), nil
	},
	"@monthStart": func() (any, error) {
		today := timeNow().UTC()
		start := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)

		d, err := types.ParseDateTime(start)
		if err != nil {
			return "", fmt.Errorf("@monthStart: %w", err)
		}

		return d.String(), nil
	},
	"@monthEnd": func() (any, error) {
		today := timeNow().UTC()
		// day 1 with an end-of-day clock, shifted +1 month -1 day,
		// resolves to the last day of the current month
		start := time.Date(today.Year(), today.Month(), 1, 23, 59, 59, 999999999, time.UTC)
		end := start.AddDate(0, 1, -1)

		d, err := types.ParseDateTime(end)
		if err != nil {
			return "", fmt.Errorf("@monthEnd: %w", err)
		}

		return d.String(), nil
	},
	"@yearStart": func() (any, error) {
		today := timeNow().UTC()
		start := time.Date(today.Year(), 1, 1, 0, 0, 0, 0, time.UTC)

		d, err := types.ParseDateTime(start)
		if err != nil {
			return "", fmt.Errorf("@yearStart: %w", err)
		}

		return d.String(), nil
	},
	"@yearEnd": func() (any, error) {
		today := timeNow().UTC()
		end := time.Date(today.Year(), 12, 31, 23, 59, 59, 999999999, time.UTC)

		d, err := types.ParseDateTime(end)
		if err != nil {
			return "", fmt.Errorf("@yearEnd: %w", err)
		}

		return d.String(), nil
	},
}

View file

@ -0,0 +1,58 @@
package search
import (
"testing"
"time"
)
func TestIdentifierMacros(t *testing.T) {
	// stub the clock so that every macro produces a deterministic value
	originalTimeNow := timeNow
	timeNow = func() time.Time {
		return time.Date(2023, 2, 3, 4, 5, 6, 7, time.UTC)
	}
	// restore the real clock even if a check below fails with t.Fatal,
	// otherwise the stub would leak into the other tests in the package
	defer func() {
		timeNow = originalTimeNow
	}()

	testMacros := map[string]any{
		"@now":        "2023-02-03 04:05:06.000Z",
		"@yesterday":  "2023-02-02 04:05:06.000Z",
		"@tomorrow":   "2023-02-04 04:05:06.000Z",
		"@second":     6,
		"@minute":     5,
		"@hour":       4,
		"@day":        3,
		"@month":      2,
		"@weekday":    5,
		"@year":       2023,
		"@todayStart": "2023-02-03 00:00:00.000Z",
		"@todayEnd":   "2023-02-03 23:59:59.999Z",
		"@monthStart": "2023-02-01 00:00:00.000Z",
		"@monthEnd":   "2023-02-28 23:59:59.999Z",
		"@yearStart":  "2023-01-01 00:00:00.000Z",
		"@yearEnd":    "2023-12-31 23:59:59.999Z",
	}

	if len(testMacros) != len(identifierMacros) {
		t.Fatalf("Expected %d macros, got %d", len(testMacros), len(identifierMacros))
	}

	for key, expected := range testMacros {
		t.Run(key, func(t *testing.T) {
			macro, ok := identifierMacros[key]
			if !ok {
				t.Fatalf("Missing macro %s", key)
			}

			result, err := macro()
			if err != nil {
				t.Fatal(err)
			}

			// %v instead of %q because the expected values are a mix of strings and ints
			if result != expected {
				t.Fatalf("Expected %v, got %v", expected, result)
			}
		})
	}
}

361
tools/search/provider.go Normal file
View file

@ -0,0 +1,361 @@
package search
import (
"errors"
"math"
"net/url"
"strconv"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/inflector"
"golang.org/x/sync/errgroup"
)
const (
// DefaultPerPage specifies the default number of returned search result items.
DefaultPerPage int = 30
// DefaultFilterExprLimit specifies the default filter expressions limit.
DefaultFilterExprLimit int = 200
// DefaultSortExprLimit specifies the default sort expressions limit.
DefaultSortExprLimit int = 8
// MaxPerPage specifies the max allowed search result items returned in a single page.
MaxPerPage int = 1000
// MaxFilterLength specifies the max allowed individual search filter parsable length.
MaxFilterLength int = 3500
// MaxSortFieldLength specifies the max allowed individual sort field parsable length.
MaxSortFieldLength int = 255
)
// Common search errors.
var (
	// ErrEmptyQuery is returned by Exec when no base query was set via Query().
	ErrEmptyQuery = errors.New("search query is not set")
	// ErrSortExprLimit is returned when the number of sort fields exceeds the provider's limit.
	ErrSortExprLimit = errors.New("max sort expressions limit reached")
	// ErrFilterExprLimit is returned when a filter parses into more expressions than allowed.
	ErrFilterExprLimit = errors.New("max filter expressions limit reached")
	// ErrFilterLengthLimit is returned when a single filter string exceeds MaxFilterLength.
	ErrFilterLengthLimit = errors.New("max filter length limit reached")
	// ErrSortFieldLengthLimit is returned when a single sort field name exceeds MaxSortFieldLength.
	ErrSortFieldLengthLimit = errors.New("max sort field length limit reached")
)
// URL search query params
const (
PageQueryParam string = "page"
PerPageQueryParam string = "perPage"
SortQueryParam string = "sort"
FilterQueryParam string = "filter"
SkipTotalQueryParam string = "skipTotal"
)
// Result defines the returned search result structure.
type Result struct {
	// Items holds the fetched models slice for the current page.
	Items any `json:"items"`
	// Page is the normalized 1-based page number.
	Page int `json:"page"`
	// PerPage is the normalized number of items per page.
	PerPage int `json:"perPage"`
	// TotalItems is the total number of matching records
	// (-1 when the count query was skipped via skipTotal).
	TotalItems int `json:"totalItems"`
	// TotalPages is the total number of pages
	// (-1 when the count query was skipped via skipTotal).
	TotalPages int `json:"totalPages"`
}
// Provider represents a single configured search provider instance.
type Provider struct {
	// fieldResolver resolves the incoming filter/sort fields to db identifiers.
	fieldResolver FieldResolver
	// query is the base select query the search constraints are applied on.
	query *dbx.SelectQuery
	// countCol is the column used for the COUNT statement (ignored when skipTotal is set).
	countCol string
	// sort holds the sort fields to apply (validated during Exec).
	sort []SortField
	// filter holds the filter expressions to apply (validated during Exec).
	filter []FilterData
	// page is the 1-based page number (normalized during Exec; values <= 0 become 1).
	page int
	// perPage is the page size (normalized during Exec between 1 and MaxPerPage).
	perPage int
	// skipTotal indicates whether to skip the extra total count query.
	skipTotal bool
	// maxFilterExprLimit is the max parsed expressions allowed per individual filter.
	maxFilterExprLimit int
	// maxSortExprLimit is the max allowed number of sort fields.
	maxSortExprLimit int
}
// NewProvider initializes and returns a new search provider
// preconfigured with the package defaults.
//
// Example:
//
//	baseQuery := db.Select("*").From("user")
//	fieldResolver := search.NewSimpleFieldResolver("id", "name")
//	models := []*YourDataStruct{}
//
//	result, err := search.NewProvider(fieldResolver).
//	    Query(baseQuery).
//	    ParseAndExec("page=2&filter=id>0&sort=-email", &models)
func NewProvider(fieldResolver FieldResolver) *Provider {
	p := &Provider{
		fieldResolver:      fieldResolver,
		countCol:           "id",
		page:               1,
		perPage:            DefaultPerPage,
		maxFilterExprLimit: DefaultFilterExprLimit,
		maxSortExprLimit:   DefaultSortExprLimit,
	}

	// start with empty (non-nil) slices so that JSON encoding
	// and appends behave consistently
	p.sort = []SortField{}
	p.filter = []FilterData{}

	return p
}
// MaxFilterExprLimit changes the default max allowed filter expressions.
//
// Note that currently the limit is applied individually for each separate filter.
// The limit is enforced during Exec().
func (s *Provider) MaxFilterExprLimit(limit int) *Provider {
	// note: the parameter was renamed from "max" to avoid shadowing
	// the Go 1.21+ builtin of the same name
	s.maxFilterExprLimit = limit
	return s
}
// MaxSortExprLimit changes the default max allowed sort expressions.
//
// The limit is enforced during Exec().
func (s *Provider) MaxSortExprLimit(limit int) *Provider {
	// note: the parameter was renamed from "max" to avoid shadowing
	// the Go 1.21+ builtin of the same name
	s.maxSortExprLimit = limit
	return s
}
// Query sets the base query that will be used to fetch the search items.
//
// Exec returns ErrEmptyQuery if the base query is nil.
func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
	s.query = query
	return s
}
// SkipTotal changes the `skipTotal` field of the current search provider.
//
// When enabled, Exec doesn't run the extra COUNT query and the returned
// Result has TotalItems and TotalPages set to -1.
func (s *Provider) SkipTotal(skipTotal bool) *Provider {
	s.skipTotal = skipTotal
	return s
}
// CountCol allows changing the default column (id) that is used
// to generate the COUNT SQL query statement.
//
// During Exec the column is prefixed with the first FROM table (if any).
//
// This field is ignored if skipTotal is true.
func (s *Provider) CountCol(name string) *Provider {
	s.countCol = name
	return s
}
// Page sets the `page` field of the current search provider.
//
// Normalization on the `page` value is done during `Exec()`
// (values <= 0 fallback to page 1).
func (s *Provider) Page(page int) *Provider {
	s.page = page
	return s
}
// PerPage sets the `perPage` field of the current search provider.
//
// Normalization on the `perPage` value is done during `Exec()`
// (values <= 0 fallback to DefaultPerPage; values above MaxPerPage are capped).
func (s *Provider) PerPage(perPage int) *Provider {
	s.perPage = perPage
	return s
}
// Sort sets the `sort` field of the current search provider.
//
// Any previously set sort fields are replaced.
func (s *Provider) Sort(sort []SortField) *Provider {
	s.sort = sort
	return s
}
// AddSort appends the provided SortField to the existing provider's sort field.
//
// The field is appended as-is; validation happens later during Exec().
func (s *Provider) AddSort(field SortField) *Provider {
	s.sort = append(s.sort, field)
	return s
}
// Filter sets the `filter` field of the current search provider.
//
// Any previously set filters are replaced.
func (s *Provider) Filter(filter []FilterData) *Provider {
	s.filter = filter
	return s
}
// AddFilter appends the provided FilterData to the existing provider's filter field.
//
// Empty filters are silently ignored.
func (s *Provider) AddFilter(filter FilterData) *Provider {
	if filter == "" {
		return s
	}

	s.filter = append(s.filter, filter)

	return s
}
// Parse parses the search query parameter from the provided query string
// and assigns the found fields to the current search provider.
//
// The data from the "sort" and "filter" query parameters are appended
// to the existing provider's `sort` and `filter` fields
// (aka. using `AddSort` and `AddFilter`).
func (s *Provider) Parse(urlQuery string) error {
	params, err := url.ParseQuery(urlQuery)
	if err != nil {
		return err
	}

	rawSkipTotal := params.Get(SkipTotalQueryParam)
	if rawSkipTotal != "" {
		skipTotal, parseErr := strconv.ParseBool(rawSkipTotal)
		if parseErr != nil {
			return parseErr
		}
		s.SkipTotal(skipTotal)
	}

	rawPage := params.Get(PageQueryParam)
	if rawPage != "" {
		page, parseErr := strconv.Atoi(rawPage)
		if parseErr != nil {
			return parseErr
		}
		s.Page(page)
	}

	rawPerPage := params.Get(PerPageQueryParam)
	if rawPerPage != "" {
		perPage, parseErr := strconv.Atoi(rawPerPage)
		if parseErr != nil {
			return parseErr
		}
		s.PerPage(perPage)
	}

	rawSort := params.Get(SortQueryParam)
	if rawSort != "" {
		for _, sortField := range ParseSortFromString(rawSort) {
			s.AddSort(sortField)
		}
	}

	rawFilter := params.Get(FilterQueryParam)
	if rawFilter != "" {
		s.AddFilter(FilterData(rawFilter))
	}

	return nil
}
// Exec executes the search provider and fills/scans
// the provided `items` slice with the found models.
//
// It applies the configured filters and sort fields on top of a shallow
// clone of the base query, normalizes page/perPage (writing the normalized
// values back to the provider) and returns the pagination Result.
// When skipTotal is set, the COUNT query is not executed and the returned
// TotalItems/TotalPages are -1.
func (s *Provider) Exec(items any) (*Result, error) {
	if s.query == nil {
		return nil, ErrEmptyQuery
	}

	// shallow clone the provider's query
	// (so that the WHERE/ORDER BY additions below don't mutate s.query)
	modelsQuery := *s.query

	// build filters
	for _, f := range s.filter {
		// reject overly long filter strings before attempting to parse them
		if len(f) > MaxFilterLength {
			return nil, ErrFilterLengthLimit
		}
		expr, err := f.BuildExprWithLimit(s.fieldResolver, s.maxFilterExprLimit)
		if err != nil {
			return nil, err
		}
		if expr != nil {
			modelsQuery.AndWhere(expr)
		}
	}

	// apply sorting
	if len(s.sort) > s.maxSortExprLimit {
		return nil, ErrSortExprLimit
	}
	for _, sortField := range s.sort {
		if len(sortField.Name) > MaxSortFieldLength {
			return nil, ErrSortFieldLengthLimit
		}
		expr, err := sortField.BuildExpr(s.fieldResolver)
		if err != nil {
			return nil, err
		}
		if expr != "" {
			// ensure that _rowid_ expressions are always prefixed with the first FROM table
			if sortField.Name == rowidSortKey && !strings.Contains(expr, ".") {
				queryInfo := modelsQuery.Info()
				if len(queryInfo.From) > 0 {
					expr = "[[" + inflector.Columnify(queryInfo.From[0]) + "]]." + expr
				}
			}
			modelsQuery.AndOrderBy(expr)
		}
	}

	// apply field resolver query modifications (if any)
	if err := s.fieldResolver.UpdateQuery(&modelsQuery); err != nil {
		return nil, err
	}

	// normalize page
	// (note: the normalized value is written back to the provider field)
	if s.page <= 0 {
		s.page = 1
	}

	// normalize perPage
	if s.perPage <= 0 {
		s.perPage = DefaultPerPage
	} else if s.perPage > MaxPerPage {
		s.perPage = MaxPerPage
	}

	// negative value to differentiate from the zero default
	totalCount := -1
	totalPages := -1

	// prepare a count query from the base one
	countQuery := modelsQuery // shallow clone
	countExec := func() error {
		queryInfo := countQuery.Info()
		countCol := s.countCol
		// prefix the count column with the first FROM table to avoid ambiguity
		if len(queryInfo.From) > 0 {
			countCol = queryInfo.From[0] + "." + countCol
		}

		// note: countQuery is shallow cloned and slice/map in-place modifications should be avoided
		err := countQuery.Distinct(false).
			Select("COUNT(DISTINCT [[" + countCol + "]])").
			OrderBy( /* reset */ ).
			Row(&totalCount)
		if err != nil {
			return err
		}

		totalPages = int(math.Ceil(float64(totalCount) / float64(s.perPage)))

		return nil
	}

	// apply pagination to the original query and fetch the models
	modelsExec := func() error {
		modelsQuery.Limit(int64(s.perPage))
		modelsQuery.Offset(int64(s.perPage * (s.page - 1)))

		return modelsQuery.All(items)
	}

	if !s.skipTotal {
		// execute the 2 queries concurrently
		errg := new(errgroup.Group)
		errg.SetLimit(2)
		errg.Go(countExec)
		errg.Go(modelsExec)
		if err := errg.Wait(); err != nil {
			return nil, err
		}
	} else {
		// only the models query is needed when the total count is skipped
		if err := modelsExec(); err != nil {
			return nil, err
		}
	}

	result := &Result{
		Page:       s.page,
		PerPage:    s.perPage,
		TotalItems: totalCount,
		TotalPages: totalPages,
		Items:      items,
	}

	return result, nil
}
// ParseAndExec is a short convenient method to trigger both
// `Parse()` and `Exec()` in a single call.
func (s *Provider) ParseAndExec(urlQuery string, modelsSlice any) (*Result, error) {
	err := s.Parse(urlQuery)
	if err != nil {
		return nil, err
	}

	return s.Exec(modelsSlice)
}

View file

@ -0,0 +1,794 @@
package search
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/list"
_ "modernc.org/sqlite"
)
func TestNewProvider(t *testing.T) {
r := &testFieldResolver{}
p := NewProvider(r)
if p.page != 1 {
t.Fatalf("Expected page %d, got %d", 1, p.page)
}
if p.perPage != DefaultPerPage {
t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
}
if p.maxFilterExprLimit != DefaultFilterExprLimit {
t.Fatalf("Expected maxFilterExprLimit %d, got %d", DefaultFilterExprLimit, p.maxFilterExprLimit)
}
if p.maxSortExprLimit != DefaultSortExprLimit {
t.Fatalf("Expected maxSortExprLimit %d, got %d", DefaultSortExprLimit, p.maxSortExprLimit)
}
}
func TestMaxFilterExprLimit(t *testing.T) {
	p := NewProvider(&testFieldResolver{})

	// the setter should accept any value, incl. zero and negative ones
	for _, limit := range []int{0, -10, 10} {
		t.Run("max_"+strconv.Itoa(limit), func(t *testing.T) {
			p.MaxFilterExprLimit(limit)

			if p.maxFilterExprLimit != limit {
				t.Fatalf("Expected maxFilterExprLimit to change to %d, got %d", limit, p.maxFilterExprLimit)
			}
		})
	}
}
func TestMaxSortExprLimit(t *testing.T) {
p := NewProvider(&testFieldResolver{})
testVals := []int{0, -10, 10}
for _, val := range testVals {
t.Run("max_"+strconv.Itoa(val), func(t *testing.T) {
p.MaxSortExprLimit(val)
if p.maxSortExprLimit != val {
t.Fatalf("Expected maxSortExprLimit to change to %d, got %d", val, p.maxSortExprLimit)
}
})
}
}
func TestProviderQuery(t *testing.T) {
db := dbx.NewFromDB(nil, "")
query := db.Select("id").From("test")
querySql := query.Build().SQL()
r := &testFieldResolver{}
p := NewProvider(r).Query(query)
expected := p.query.Build().SQL()
if querySql != expected {
t.Fatalf("Expected %v, got %v", expected, querySql)
}
}
func TestProviderSkipTotal(t *testing.T) {
p := NewProvider(&testFieldResolver{})
if p.skipTotal {
t.Fatalf("Expected the default skipTotal to be %v, got %v", false, p.skipTotal)
}
p.SkipTotal(true)
if !p.skipTotal {
t.Fatalf("Expected skipTotal to change to %v, got %v", true, p.skipTotal)
}
}
func TestProviderCountCol(t *testing.T) {
p := NewProvider(&testFieldResolver{})
if p.countCol != "id" {
t.Fatalf("Expected the default countCol to be %s, got %s", "id", p.countCol)
}
p.CountCol("test")
if p.countCol != "test" {
t.Fatalf("Expected colCount to change to %s, got %s", "test", p.countCol)
}
}
func TestProviderPage(t *testing.T) {
r := &testFieldResolver{}
p := NewProvider(r).Page(10)
if p.page != 10 {
t.Fatalf("Expected page %v, got %v", 10, p.page)
}
}
func TestProviderPerPage(t *testing.T) {
r := &testFieldResolver{}
p := NewProvider(r).PerPage(456)
if p.perPage != 456 {
t.Fatalf("Expected perPage %v, got %v", 456, p.perPage)
}
}
func TestProviderSort(t *testing.T) {
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
r := &testFieldResolver{}
p := NewProvider(r).
Sort(initialSort).
AddSort(SortField{"test3", SortDesc})
encoded, _ := json.Marshal(p.sort)
expected := `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"test3","direction":"DESC"}]`
if string(encoded) != expected {
t.Fatalf("Expected sort %v, got \n%v", expected, string(encoded))
}
}
func TestProviderFilter(t *testing.T) {
initialFilter := []FilterData{"test1", "test2"}
r := &testFieldResolver{}
p := NewProvider(r).
Filter(initialFilter).
AddFilter("test3")
encoded, _ := json.Marshal(p.filter)
expected := `["test1","test2","test3"]`
if string(encoded) != expected {
t.Fatalf("Expected filter %v, got \n%v", expected, string(encoded))
}
}
func TestProviderParse(t *testing.T) {
initialPage := 2
initialPerPage := 123
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
initialFilter := []FilterData{"test1", "test2"}
scenarios := []struct {
query string
expectError bool
expectPage int
expectPerPage int
expectSort string
expectFilter string
}{
// empty
{
"",
false,
initialPage,
initialPerPage,
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
`["test1","test2"]`,
},
// invalid query
{
"invalid;",
true,
initialPage,
initialPerPage,
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
`["test1","test2"]`,
},
// invalid page
{
"page=a",
true,
initialPage,
initialPerPage,
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
`["test1","test2"]`,
},
// invalid perPage
{
"perPage=a",
true,
initialPage,
initialPerPage,
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
`["test1","test2"]`,
},
// valid query parameters
{
"page=3&perPage=456&filter=test3&sort=-a,b,+c&other=123",
false,
3,
456,
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"a","direction":"DESC"},{"name":"b","direction":"ASC"},{"name":"c","direction":"ASC"}]`,
`["test1","test2","test3"]`,
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.query), func(t *testing.T) {
r := &testFieldResolver{}
p := NewProvider(r).
Page(initialPage).
PerPage(initialPerPage).
Sort(initialSort).
Filter(initialFilter)
err := p.Parse(s.query)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if p.page != s.expectPage {
t.Fatalf("Expected page %v, got %v", s.expectPage, p.page)
}
if p.perPage != s.expectPerPage {
t.Fatalf("Expected perPage %v, got %v", s.expectPerPage, p.perPage)
}
encodedSort, _ := json.Marshal(p.sort)
if string(encodedSort) != s.expectSort {
t.Fatalf("Expected sort %v, got \n%v", s.expectSort, string(encodedSort))
}
encodedFilter, _ := json.Marshal(p.filter)
if string(encodedFilter) != s.expectFilter {
t.Fatalf("Expected filter %v, got \n%v", s.expectFilter, string(encodedFilter))
}
})
}
}
func TestProviderExecEmptyQuery(t *testing.T) {
	p := NewProvider(&testFieldResolver{}).Query(nil)

	// a nil base query must be rejected
	if _, err := p.Exec(&[]testTableStruct{}); err == nil {
		t.Fatalf("Expected error with empty query, got nil")
	}
}
func TestProviderExecNonEmptyQuery(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
query := testDB.Select("*").
From("test").
Where(dbx.Not(dbx.HashExp{"test1": nil})).
OrderBy("test1 ASC")
scenarios := []struct {
name string
page int
perPage int
sort []SortField
filter []FilterData
skipTotal bool
expectError bool
expectResult string
expectQueries []string
}{
{
"page normalization",
-1,
10,
[]SortField{},
[]FilterData{},
false,
false,
`{"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":10,"totalItems":2,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
},
},
{
"perPage normalization",
10,
0, // fallback to default
[]SortField{},
[]FilterData{},
false,
false,
`{"items":[],"page":10,"perPage":30,"totalItems":2,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30 OFFSET 270",
},
},
{
"invalid sort field",
1,
10,
[]SortField{{"unknown", SortAsc}},
[]FilterData{},
false,
true,
"",
nil,
},
{
"invalid filter",
1,
10,
[]SortField{},
[]FilterData{"test2 = 'test2.1'", "invalid"},
false,
true,
"",
nil,
},
{
"valid sort and filter fields",
1,
5555, // will be limited by MaxPerPage
[]SortField{{"test2", SortDesc}},
[]FilterData{"test2 != null", "test1 >= 2"},
false,
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2)",
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
},
},
{
"valid sort and filter fields (skipTotal=1)",
1,
5555, // will be limited by MaxPerPage
[]SortField{{"test2", SortDesc}},
[]FilterData{"test2 != null", "test1 >= 2"},
true,
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
},
},
{
"valid sort and filter fields (zero results)",
1,
10,
[]SortField{{"test3", SortAsc}},
[]FilterData{"test3 != ''"},
false,
false,
`{"items":[],"page":1,"perPage":10,"totalItems":0,"totalPages":0}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL)))",
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
},
},
{
"valid sort and filter fields (zero results; skipTotal=1)",
1,
10,
[]SortField{{"test3", SortAsc}},
[]FilterData{"test3 != ''"},
true,
false,
`{"items":[],"page":1,"perPage":10,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
},
},
{
"pagination test",
2,
1,
[]SortField{},
[]FilterData{},
false,
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":2,"totalPages":2}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
},
},
{
"pagination test (skipTotal=1)",
2,
1,
[]SortField{},
[]FilterData{},
true,
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testDB.CalledQueries = []string{} // reset
testResolver := &testFieldResolver{}
p := NewProvider(testResolver).
Query(query).
Page(s.page).
PerPage(s.perPage).
Sort(s.sort).
SkipTotal(s.skipTotal).
Filter(s.filter)
result, err := p.Exec(&[]testTableStruct{})
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if testResolver.UpdateQueryCalls != 1 {
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
}
encoded, _ := json.Marshal(result)
if string(encoded) != s.expectResult {
t.Fatalf("Expected result %v, got \n%v", s.expectResult, string(encoded))
}
if len(s.expectQueries) != len(testDB.CalledQueries) {
t.Fatalf("Expected %d queries, got %d: \n%v", len(s.expectQueries), len(testDB.CalledQueries), testDB.CalledQueries)
}
for _, q := range testDB.CalledQueries {
if !list.ExistInSliceWithRegex(q, s.expectQueries) {
t.Fatalf("Didn't expect query \n%v \nin \n%v", q, s.expectQueries)
}
}
})
}
}
func TestProviderFilterAndSortLimits(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
query := testDB.Select("*").
From("test").
Where(dbx.Not(dbx.HashExp{"test1": nil})).
OrderBy("test1 ASC")
scenarios := []struct {
name string
filter []FilterData
sort []SortField
maxFilterExprLimit int
maxSortExprLimit int
expectError bool
}{
// filter
{
"<= max filter length",
[]FilterData{
"1=2",
FilterData("1='" + strings.Repeat("a", MaxFilterLength-4) + "'"),
},
[]SortField{},
1,
0,
false,
},
{
"> max filter length",
[]FilterData{
"1=2",
FilterData("1='" + strings.Repeat("a", MaxFilterLength-3) + "'"),
},
[]SortField{},
1,
0,
true,
},
{
"<= max filter exprs",
[]FilterData{
"1=2",
"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
},
[]SortField{},
6,
0,
false,
},
{
"> max filter exprs",
[]FilterData{
"1=2",
"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
},
[]SortField{},
5,
0,
true,
},
// sort
{
"<= max sort field length",
[]FilterData{},
[]SortField{
{"id", SortAsc},
{"test1", SortDesc},
{strings.Repeat("a", MaxSortFieldLength), SortDesc},
},
0,
10,
false,
},
{
"> max sort field length",
[]FilterData{},
[]SortField{
{"id", SortAsc},
{"test1", SortDesc},
{strings.Repeat("b", MaxSortFieldLength+1), SortDesc},
},
0,
10,
true,
},
{
"<= max sort exprs",
[]FilterData{},
[]SortField{
{"id", SortAsc},
{"test1", SortDesc},
},
0,
2,
false,
},
{
"> max sort exprs",
[]FilterData{},
[]SortField{
{"id", SortAsc},
{"test1", SortDesc},
},
0,
1,
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testResolver := &testFieldResolver{}
p := NewProvider(testResolver).
Query(query).
Sort(s.sort).
Filter(s.filter).
MaxFilterExprLimit(s.maxFilterExprLimit).
MaxSortExprLimit(s.maxSortExprLimit)
_, err := p.Exec(&[]testTableStruct{})
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
})
}
}
func TestProviderParseAndExec(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
query := testDB.Select("*").
From("test").
Where(dbx.Not(dbx.HashExp{"test1": nil})).
OrderBy("test1 ASC")
scenarios := []struct {
name string
queryString string
expectError bool
expectResult string
}{
{
"no extra query params (aka. use the provider presets)",
"",
false,
`{"items":[],"page":2,"perPage":123,"totalItems":2,"totalPages":1}`,
},
{
"invalid query",
"invalid;",
true,
"",
},
{
"invalid page",
"page=a",
true,
"",
},
{
"invalid perPage",
"perPage=a",
true,
"",
},
{
"invalid skipTotal",
"skipTotal=a",
true,
"",
},
{
"invalid sorting field",
"sort=-unknown",
true,
"",
},
{
"invalid filter field",
"filter=unknown>1",
true,
"",
},
{
"page > existing",
"page=3&perPage=9999",
false,
`{"items":[],"page":3,"perPage":1000,"totalItems":2,"totalPages":1}`,
},
{
"valid query params",
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3",
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":1,"totalPages":1}`,
},
{
"valid query params with skipTotal=1",
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3&skipTotal=1",
false,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":-1,"totalPages":-1}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testDB.CalledQueries = []string{} // reset
testResolver := &testFieldResolver{}
provider := NewProvider(testResolver).
Query(query).
Page(2).
PerPage(123).
Sort([]SortField{{"test2", SortAsc}}).
Filter([]FilterData{"test1 > 0"})
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if testResolver.UpdateQueryCalls != 1 {
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
}
expectedQueries := 2
if provider.skipTotal {
expectedQueries = 1
}
if len(testDB.CalledQueries) != expectedQueries {
t.Fatalf("Expected %d db queries, got %d: \n%v", expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
}
encoded, _ := json.Marshal(result)
if string(encoded) != s.expectResult {
t.Fatalf("Expected result \n%v\ngot\n%v", s.expectResult, string(encoded))
}
})
}
}
// -------------------------------------------------------------------
// Helpers
// -------------------------------------------------------------------
type testTableStruct struct {
Test1 int `db:"test1" json:"test1"`
Test2 string `db:"test2" json:"test2"`
Test3 string `db:"test3" json:"test3"`
}
type testDB struct {
*dbx.DB
CalledQueries []string
}
// createTestDB creates an in-memory SQLite db prepopulated with a "test"
// table and 2 rows, and registers a query logger into CalledQueries.
//
// NB! Don't forget to call `db.Close()` at the end of the test.
func createTestDB() (*testDB, error) {
	// using a shared cache to allow multiple connections access to
	// the same in memory database https://www.sqlite.org/inmemorydb.html
	sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
	if err != nil {
		return nil, err
	}

	db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}

	// previously the setup errors were silently ignored which could
	// result in confusing downstream test failures
	if _, err := db.CreateTable("test", map[string]string{
		"id":    "int default 0",
		"test1": "int default 0",
		"test2": "text default ''",
		"test3": "text default ''",
		strings.Repeat("a", MaxSortFieldLength):   "text default ''",
		strings.Repeat("b", MaxSortFieldLength+1): "text default ''",
	}).Execute(); err != nil {
		sqlDB.Close()
		return nil, err
	}

	if _, err := db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute(); err != nil {
		sqlDB.Close()
		return nil, err
	}
	if _, err := db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute(); err != nil {
		sqlDB.Close()
		return nil, err
	}

	// record every executed query for later assertions
	db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
		db.CalledQueries = append(db.CalledQueries, sql)
	}

	return &db, nil
}
// ---
type testFieldResolver struct {
UpdateQueryCalls int
ResolveCalls int
}
func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
t.UpdateQueryCalls++
return nil
}
// Resolve returns the field as-is, erroring only for the special
// "unknown" test value.
func (t *testFieldResolver) Resolve(field string) (*ResolverResult, error) {
	t.ResolveCalls++

	if field != "unknown" {
		return &ResolverResult{Identifier: field}, nil
	}

	return nil, errors.New("test error")
}

View file

@ -0,0 +1,113 @@
package search
import (
"fmt"
"strconv"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/list"
)
// ResolverResult defines a single FieldResolver.Resolve() successfully parsed result.
type ResolverResult struct {
// Identifier is the plain SQL identifier/column that will be used
// in the final db expression as left or right operand.
Identifier string
// NoCoalesce instructs to not use COALESCE or NULL fallbacks
// when building the identifier expression.
NoCoalesce bool
// Params is a map with db placeholder->value pairs that will be added
// to the query when building both resolved operands/sides in a single expression.
Params dbx.Params
// MultiMatchSubQuery is an optional sub query expression that will be added
// in addition to the combined ResolverResult expression during build.
MultiMatchSubQuery dbx.Expression
// AfterBuild is an optional function that will be called after building
// and combining the result of both resolved operands/sides in a single expression.
AfterBuild func(expr dbx.Expression) dbx.Expression
}
// FieldResolver defines an interface for managing search fields.
type FieldResolver interface {
// UpdateQuery allows to updated the provided db query based on the
// resolved search fields (eg. adding joins aliases, etc.).
//
// Called internally by `search.Provider` before executing the search request.
UpdateQuery(query *dbx.SelectQuery) error
// Resolve parses the provided field and returns a properly
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
Resolve(field string) (*ResolverResult, error)
}
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
// provided `allowedFields`.
//
// Each `allowedFields` could be a plain string (eg. "name")
// or a regexp pattern (eg. `^\w+[\w\.]*$`).
func NewSimpleFieldResolver(allowedFields ...string) *SimpleFieldResolver {
	r := &SimpleFieldResolver{}
	r.allowedFields = allowedFields
	return r
}
// SimpleFieldResolver defines a generic search resolver that allows
// only its listed fields to be resolved and take part in a search query.
//
// If `allowedFields` are empty no fields filtering is applied.
type SimpleFieldResolver struct {
allowedFields []string
}
// UpdateQuery implements `search.UpdateQuery` interface.
//
// The simple resolver doesn't register joins or aliases,
// so there is nothing to modify and this is a no-op.
func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
	// nothing to update...
	return nil
}
// Resolve implements `search.Resolve` interface.
//
// Returns error if `field` is not in `r.allowedFields`.
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
	if !list.ExistInSliceWithRegex(field, r.allowedFields) {
		return nil, fmt.Errorf("failed to resolve field %q", field)
	}

	parts := strings.Split(field, ".")

	// plain column field
	if len(parts) == 1 {
		return &ResolverResult{
			Identifier: "[[" + inflector.Columnify(parts[0]) + "]]",
		}, nil
	}

	// everything after the first part is treated as a json path
	var jsonPath strings.Builder
	jsonPath.WriteString("$")
	for _, part := range parts[1:] {
		_, numErr := strconv.Atoi(part)
		if numErr == nil {
			// numeric part -> array index accessor
			jsonPath.WriteString("[")
			jsonPath.WriteString(inflector.Columnify(part))
			jsonPath.WriteString("]")
		} else {
			// object key accessor
			jsonPath.WriteString(".")
			jsonPath.WriteString(inflector.Columnify(part))
		}
	}

	return &ResolverResult{
		NoCoalesce: true,
		Identifier: "JSON_EXTRACT([[" + inflector.Columnify(parts[0]) + "]], '" + jsonPath.String() + "')",
	}, nil
}

View file

@ -0,0 +1,87 @@
package search_test
import (
"fmt"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/search"
)
func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
r := search.NewSimpleFieldResolver("test")
scenarios := []struct {
fieldName string
expectQuery string
}{
// missing field (the query shouldn't change)
{"", `SELECT "id" FROM "test"`},
// unknown field (the query shouldn't change)
{"unknown", `SELECT "id" FROM "test"`},
// allowed field (the query shouldn't change)
{"test", `SELECT "id" FROM "test"`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
db := dbx.NewFromDB(nil, "")
query := db.Select("id").From("test")
r.Resolve(s.fieldName)
if err := r.UpdateQuery(nil); err != nil {
t.Fatalf("UpdateQuery failed with error %v", err)
}
rawQuery := query.Build().SQL()
if rawQuery != s.expectQuery {
t.Fatalf("Expected query %v, got \n%v", s.expectQuery, rawQuery)
}
})
}
}
func TestSimpleFieldResolverResolve(t *testing.T) {
	// note: the non-shadowing "resolver" name keeps the resolver
	// distinguishable from the per-scenario resolve result
	resolver := search.NewSimpleFieldResolver("test", `^test_regex\d+$`, "Test columnify!", "data.test")

	scenarios := []struct {
		fieldName   string
		expectError bool
		expectName  string
	}{
		{"", true, ""},
		{" ", true, ""},
		{"unknown", true, ""},
		{"test", false, "[[test]]"},
		{"test.sub", true, ""},
		{"test_regex", true, ""},
		{"test_regex1", false, "[[test_regex1]]"},
		{"Test columnify!", false, "[[Testcolumnify]]"},
		{"data.test", false, "JSON_EXTRACT([[data]], '$.test')"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
			result, err := resolver.Resolve(s.fieldName)

			hasErr := err != nil
			if hasErr != s.expectError {
				t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
			}
			if hasErr {
				return
			}

			if result.Identifier != s.expectName {
				t.Fatalf("Expected r.Identifier %q, got %q", s.expectName, result.Identifier)
			}

			if len(result.Params) != 0 {
				t.Fatalf("r.Params should be empty, got %v", result.Params)
			}
		})
	}
}

67
tools/search/sort.go Normal file
View file

@ -0,0 +1,67 @@
package search
import (
"fmt"
"strings"
)
// special sort keys that are handled separately from regular columns
const (
	// randomSortKey resolves to a RANDOM() sort expression.
	randomSortKey string = "@random"
	// rowidSortKey resolves to the builtin SQLite _rowid_ column.
	rowidSortKey string = "@rowid"
)

// sort field directions
const (
	SortAsc  string = "ASC"
	SortDesc string = "DESC"
)

// SortField defines a single search sort field.
type SortField struct {
	// Name is the resolvable field name (or one of the special @ keys).
	Name string `json:"name"`

	// Direction is either SortAsc or SortDesc.
	Direction string `json:"direction"`
}
// BuildExpr resolves the sort field into a valid db sort expression.
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
	switch s.Name {
	case randomSortKey:
		// the direction is irrelevant for random ordering
		return "RANDOM()", nil
	case rowidSortKey:
		// builtin SQLite rowid column
		return "[[_rowid_]] " + s.Direction, nil
	}

	resolved, err := fieldResolver.Resolve(s.Name)
	if err != nil {
		return "", fmt.Errorf("invalid sort field %q", s.Name)
	}

	// reject empty fields and non-column identifiers
	if len(resolved.Params) > 0 || resolved.Identifier == "" || strings.ToLower(resolved.Identifier) == "null" {
		return "", fmt.Errorf("invalid sort field %q", s.Name)
	}

	return resolved.Identifier + " " + s.Direction, nil
}
// ParseSortFromString parses the provided string expression
// into a slice of SortFields.
//
// Example:
//
//	fields := search.ParseSortFromString("-name,+created")
func ParseSortFromString(str string) []SortField {
	rawFields := strings.Split(str, ",")

	fields := make([]SortField, 0, len(rawFields))

	for _, raw := range rawFields {
		raw = strings.TrimSpace(raw)

		var name, direction string
		if strings.HasPrefix(raw, "-") {
			// "-" prefix -> descending
			name, direction = raw[1:], SortDesc
		} else {
			// optional "+" prefix -> ascending (the default)
			name, direction = strings.TrimPrefix(raw, "+"), SortAsc
		}

		fields = append(fields, SortField{name, direction})
	}

	return fields
}

78
tools/search/sort_test.go Normal file
View file

@ -0,0 +1,78 @@
package search_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/search"
)
func TestSortFieldBuildExpr(t *testing.T) {
	resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")

	scenarios := []struct {
		sortField        search.SortField
		expectError      bool
		expectExpression string
	}{
		// empty
		{search.SortField{"", search.SortDesc}, true, ""},
		// unknown field
		{search.SortField{"unknown", search.SortAsc}, true, ""},
		// placeholder field
		{search.SortField{"'test'", search.SortAsc}, true, ""},
		// null field
		{search.SortField{"null", search.SortAsc}, true, ""},
		// allowed field - asc
		{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
		// allowed field - desc
		{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
		// special @random field (ignore direction)
		{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
		// special _rowid_ field
		{search.SortField{"@rowid", search.SortDesc}, false, "[[_rowid_]] DESC"},
	}

	for _, s := range scenarios {
		// use Name + Direction in the subtest name (previously Name was
		// repeated twice, making same-name scenarios indistinguishable)
		t.Run(fmt.Sprintf("%s_%s", s.sortField.Name, s.sortField.Direction), func(t *testing.T) {
			result, err := s.sortField.BuildExpr(resolver)

			hasErr := err != nil
			if hasErr != s.expectError {
				t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
			}

			if result != s.expectExpression {
				t.Fatalf("Expected expression %v, got %v", s.expectExpression, result)
			}
		})
	}
}
func TestParseSortFromString(t *testing.T) {
	scenarios := []struct {
		value    string
		expected string
	}{
		{"", `[{"name":"","direction":"ASC"}]`},
		{"test", `[{"name":"test","direction":"ASC"}]`},
		{"+test", `[{"name":"test","direction":"ASC"}]`},
		{"-test", `[{"name":"test","direction":"DESC"}]`},
		{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
		{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
		{"-@rowid,-test", `[{"name":"@rowid","direction":"DESC"},{"name":"test","direction":"DESC"}]`},
	}

	for _, s := range scenarios {
		t.Run(s.value, func(t *testing.T) {
			fields := search.ParseSortFromString(s.value)

			// compare the parsed fields via their JSON serialization
			raw, _ := json.Marshal(fields)

			if encodedStr := string(raw); encodedStr != s.expected {
				t.Fatalf("Expected expression %s, got %s", s.expected, encodedStr)
			}
		})
	}
}

View file

@ -0,0 +1,56 @@
package search
import (
"fmt"
"github.com/ganigeorgiev/fexpr"
)
var TokenFunctions = map[string]func(
	argTokenResolverFunc func(fexpr.Token) (*ResolverResult, error),
	args ...fexpr.Token,
) (*ResolverResult, error){
	// geoDistance(lonA, latA, lonB, latB) calculates the Haversine
	// distance between 2 points in kilometres (https://www.movable-type.co.uk/scripts/latlong.html).
	//
	// The accepted arguments at the moment could be either a plain number or a column identifier (including NULL).
	// If the column identifier cannot be resolved and converted to a numeric value, it resolves to NULL.
	//
	// Similar to the built-in SQLite functions, geoDistance doesn't apply
	// a "match-all" constraints in case there are multiple relation fields arguments.
	// Or in other words, if a collection has "orgs" multiple relation field pointing to "orgs" collection that has "office" as "geoPoint" field,
	// then the filter: `geoDistance(orgs.office.lon, orgs.office.lat, 1, 2) < 200`
	// will evaluate to true if for at-least-one of the "orgs.office" records the function result in a value satisfying the condition (aka. "result < 200").
	"geoDistance": func(argTokenResolverFunc func(fexpr.Token) (*ResolverResult, error), args ...fexpr.Token) (*ResolverResult, error) {
		if len(args) != 4 {
			return nil, fmt.Errorf("[geoDistance] expected 4 arguments, got %d", len(args))
		}

		resolved := make([]*ResolverResult, 4)
		for i, arg := range args {
			switch arg.Type {
			case fexpr.TokenIdentifier, fexpr.TokenNumber:
				// supported argument token types
			default:
				return nil, fmt.Errorf("[geoDistance] argument %d must be an identifier or number", i)
			}

			result, err := argTokenResolverFunc(arg)
			if err != nil {
				return nil, fmt.Errorf("[geoDistance] failed to resolve argument %d: %w", i, err)
			}

			resolved[i] = result
		}

		lonA, latA := resolved[0].Identifier, resolved[1].Identifier
		lonB, latB := resolved[2].Identifier, resolved[3].Identifier

		return &ResolverResult{
			NoCoalesce: true,
			// Haversine formula with Earth radius ~6371km
			Identifier: fmt.Sprintf(
				"(6371 * acos(cos(radians(%[2]s)) * cos(radians(%[4]s)) * cos(radians(%[3]s) - radians(%[1]s)) + sin(radians(%[2]s)) * sin(radians(%[4]s))))",
				lonA, latA, lonB, latB,
			),
			Params: mergeParams(resolved[0].Params, resolved[1].Params, resolved[2].Params, resolved[3].Params),
		}, nil
	},
}

View file

@ -0,0 +1,277 @@
package search
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/ganigeorgiev/fexpr"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/security"
)
func TestTokenFunctionsGeoDistance(t *testing.T) {
t.Parallel()
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
fn, ok := TokenFunctions["geoDistance"]
if !ok {
t.Error("Expected geoDistance token function to be registered.")
}
baseTokenResolver := func(t fexpr.Token) (*ResolverResult, error) {
placeholder := "t" + security.PseudorandomString(5)
return &ResolverResult{Identifier: "{:" + placeholder + "}", Params: map[string]any{placeholder: t.Literal}}, nil
}
scenarios := []struct {
name string
args []fexpr.Token
resolver func(t fexpr.Token) (*ResolverResult, error)
result *ResolverResult
expectErr bool
}{
{
"no args",
nil,
baseTokenResolver,
nil,
true,
},
{
"< 4 args",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenNumber},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
},
baseTokenResolver,
nil,
true,
},
{
"> 4 args",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenNumber},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
{Literal: "4", Type: fexpr.TokenNumber},
{Literal: "5", Type: fexpr.TokenNumber},
},
baseTokenResolver,
nil,
true,
},
{
"unsupported function argument",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenFunction},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
{Literal: "4", Type: fexpr.TokenNumber},
},
baseTokenResolver,
nil,
true,
},
{
"unsupported text argument",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenText},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
{Literal: "4", Type: fexpr.TokenNumber},
},
baseTokenResolver,
nil,
true,
},
{
"4 valid arguments but with resolver error",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenNumber},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
{Literal: "4", Type: fexpr.TokenNumber},
},
func(t fexpr.Token) (*ResolverResult, error) {
return nil, errors.New("test")
},
nil,
true,
},
{
"4 valid arguments",
[]fexpr.Token{
{Literal: "1", Type: fexpr.TokenNumber},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "3", Type: fexpr.TokenNumber},
{Literal: "4", Type: fexpr.TokenNumber},
},
baseTokenResolver,
&ResolverResult{
NoCoalesce: true,
Identifier: `(6371 * acos(cos(radians({:latA})) * cos(radians({:latB})) * cos(radians({:lonB}) - radians({:lonA})) + sin(radians({:latA})) * sin(radians({:latB}))))`,
Params: map[string]any{
"lonA": 1,
"latA": 2,
"lonB": 3,
"latB": 4,
},
},
false,
},
{
"mixed arguments",
[]fexpr.Token{
{Literal: "null", Type: fexpr.TokenIdentifier},
{Literal: "2", Type: fexpr.TokenNumber},
{Literal: "false", Type: fexpr.TokenIdentifier},
{Literal: "4", Type: fexpr.TokenNumber},
},
baseTokenResolver,
&ResolverResult{
NoCoalesce: true,
Identifier: `(6371 * acos(cos(radians({:latA})) * cos(radians({:latB})) * cos(radians({:lonB}) - radians({:lonA})) + sin(radians({:latA})) * sin(radians({:latB}))))`,
Params: map[string]any{
"lonA": "null",
"latA": 2,
"lonB": false,
"latB": 4,
},
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result, err := fn(s.resolver, s.args...)
hasErr := err != nil
if hasErr != s.expectErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
}
testCompareResults(t, s.result, result)
})
}
}
func TestTokenFunctionsGeoDistanceExec(t *testing.T) {
	t.Parallel()

	testDB, err := createTestDB()
	if err != nil {
		t.Fatal(err)
	}
	defer testDB.Close()

	fn, ok := TokenFunctions["geoDistance"]
	if !ok {
		// Fatal (previously Error) because a nil fn would panic when called below
		t.Fatal("Expected geoDistance token function to be registered.")
	}

	result, err := fn(
		func(t fexpr.Token) (*ResolverResult, error) {
			placeholder := "t" + security.PseudorandomString(5)
			return &ResolverResult{Identifier: "{:" + placeholder + "}", Params: map[string]any{placeholder: t.Literal}}, nil
		},
		fexpr.Token{Literal: "23.23033854945808", Type: fexpr.TokenNumber},
		fexpr.Token{Literal: "42.713146090563384", Type: fexpr.TokenNumber},
		fexpr.Token{Literal: "23.44920680886216", Type: fexpr.TokenNumber},
		fexpr.Token{Literal: "42.7078484153991", Type: fexpr.TokenNumber},
	)
	if err != nil {
		t.Fatal(err)
	}

	// execute the generated expression against the test DB
	column := []float64{}
	err = testDB.NewQuery("select " + result.Identifier).Bind(result.Params).Column(&column)
	if err != nil {
		t.Fatal(err)
	}

	if len(column) != 1 {
		t.Fatalf("Expected exactly 1 column value as result, got %v", column)
	}

	expected := "17.89"

	distance := fmt.Sprintf("%.2f", column[0])
	if distance != expected {
		t.Fatalf("Expected distance value %s, got %s", expected, distance)
	}
}
// -------------------------------------------------------------------
// testCompareResults fails the test if the two ResolverResult values
// are not functionally equivalent.
//
// Identifiers are compared loosely by replacing each placeholder
// with its bound param value, since placeholder names are random.
func testCompareResults(t *testing.T, a, b *ResolverResult) {
	t.Helper()

	testDB, err := createTestDB()
	if err != nil {
		t.Fatal(err)
	}
	defer testDB.Close()

	aIsNil := a == nil
	bIsNil := b == nil
	if aIsNil != bIsNil {
		t.Fatalf("Expected aIsNil and bIsNil to be the same, got %v vs %v", aIsNil, bIsNil)
	}
	if aIsNil && bIsNil {
		return
	}

	// both results must either define or omit the AfterBuild hook
	// (renamed from the inverted "aHasAfterBuild" which actually checked for nil)
	aNilAfterBuild := a.AfterBuild == nil
	bNilAfterBuild := b.AfterBuild == nil
	if aNilAfterBuild != bNilAfterBuild {
		t.Fatalf("Expected aNilAfterBuild and bNilAfterBuild to be the same, got %v vs %v", aNilAfterBuild, bNilAfterBuild)
	}

	var aAfterBuild string
	if a.AfterBuild != nil {
		aAfterBuild = a.AfterBuild(dbx.NewExp("test")).Build(testDB.DB, a.Params)
	}
	var bAfterBuild string
	if b.AfterBuild != nil {
		// build with b's own params (previously a.Params by mistake)
		bAfterBuild = b.AfterBuild(dbx.NewExp("test")).Build(testDB.DB, b.Params)
	}
	if aAfterBuild != bAfterBuild {
		t.Fatalf("Expected aAfterBuild and bAfterBuild to be the same, got\n%s\nvs\n%s", aAfterBuild, bAfterBuild)
	}

	var aMultiMatchSubQuery string
	if a.MultiMatchSubQuery != nil {
		aMultiMatchSubQuery = a.MultiMatchSubQuery.Build(testDB.DB, a.Params)
	}
	var bMultiMatchSubQuery string
	if b.MultiMatchSubQuery != nil {
		bMultiMatchSubQuery = b.MultiMatchSubQuery.Build(testDB.DB, b.Params)
	}
	if aMultiMatchSubQuery != bMultiMatchSubQuery {
		t.Fatalf("Expected aMultiMatchSubQuery and bMultiMatchSubQuery to be the same, got\n%s\nvs\n%s", aMultiMatchSubQuery, bMultiMatchSubQuery)
	}

	if a.NoCoalesce != b.NoCoalesce {
		t.Fatalf("Expected NoCoalesce to match, got %v vs %v", a.NoCoalesce, b.NoCoalesce)
	}

	// loose placeholders replacement
	var aResolved = a.Identifier
	for k, v := range a.Params {
		aResolved = strings.ReplaceAll(aResolved, "{:"+k+"}", fmt.Sprintf("%v", v))
	}
	var bResolved = b.Identifier
	for k, v := range b.Params {
		bResolved = strings.ReplaceAll(bResolved, "{:"+k+"}", fmt.Sprintf("%v", v))
	}
	if aResolved != bResolved {
		t.Fatalf("Expected resolved identifiers to match, got\n%s\nvs\n%s", aResolved, bResolved)
	}
}