Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
718
tools/search/filter.go
Normal file
718
tools/search/filter.go
Normal file
|
@ -0,0 +1,718 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ganigeorgiev/fexpr"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// FilterData is a filter expression string following the `fexpr` package grammar.
|
||||
//
|
||||
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
|
||||
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var filter FilterData = "id = null || (name = 'test' && status = true) || (total >= {:min} && total <= {:max})"
|
||||
// resolver := search.NewSimpleFieldResolver("id", "name", "status")
|
||||
// expr, err := filter.BuildExpr(resolver, dbx.Params{"min": 100, "max": 200})
|
||||
type FilterData string
|
||||
|
||||
// parsedFilterData holds a cache with previously parsed filter data expressions
|
||||
// (initialized with some preallocated empty data map)
|
||||
var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))
|
||||
|
||||
// BuildExpr parses the current filter data and returns a new db WHERE expression.
|
||||
//
|
||||
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
|
||||
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
|
||||
//
|
||||
// The parsed expressions are limited up to DefaultFilterExprLimit.
|
||||
// Use [FilterData.BuildExprWithLimit] if you want to set a custom limit.
|
||||
func (f FilterData) BuildExpr(
|
||||
fieldResolver FieldResolver,
|
||||
placeholderReplacements ...dbx.Params,
|
||||
) (dbx.Expression, error) {
|
||||
return f.BuildExprWithLimit(fieldResolver, DefaultFilterExprLimit, placeholderReplacements...)
|
||||
}
|
||||
|
||||
// BuildExpr parses the current filter data and returns a new db WHERE expression.
|
||||
//
|
||||
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
|
||||
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
|
||||
func (f FilterData) BuildExprWithLimit(
|
||||
fieldResolver FieldResolver,
|
||||
maxExpressions int,
|
||||
placeholderReplacements ...dbx.Params,
|
||||
) (dbx.Expression, error) {
|
||||
raw := string(f)
|
||||
|
||||
// replace the placeholder params in the raw string filter
|
||||
for _, p := range placeholderReplacements {
|
||||
for key, value := range p {
|
||||
var replacement string
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
replacement = "null"
|
||||
case bool, float64, float32, int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
|
||||
replacement = cast.ToString(v)
|
||||
default:
|
||||
replacement = cast.ToString(v)
|
||||
|
||||
// try to json serialize as fallback
|
||||
if replacement == "" {
|
||||
raw, _ := json.Marshal(v)
|
||||
replacement = string(raw)
|
||||
}
|
||||
|
||||
replacement = strconv.Quote(replacement)
|
||||
}
|
||||
raw = strings.ReplaceAll(raw, "{:"+key+"}", replacement)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := raw + "/" + strconv.Itoa(maxExpressions)
|
||||
|
||||
if data, ok := parsedFilterData.GetOk(cacheKey); ok {
|
||||
return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
|
||||
}
|
||||
|
||||
data, err := fexpr.Parse(raw)
|
||||
if err != nil {
|
||||
// depending on the users demand we may allow empty expressions
|
||||
// (aka. expressions consisting only of whitespaces or comments)
|
||||
// but for now disallow them as it seems unnecessary
|
||||
// if errors.Is(err, fexpr.ErrEmpty) {
|
||||
// return dbx.NewExp("1=1"), nil
|
||||
// }
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// store in cache
|
||||
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
|
||||
parsedFilterData.SetIfLessThanLimit(cacheKey, data, 500)
|
||||
|
||||
return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
|
||||
}
|
||||
|
||||
func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver, maxExpressions *int) (dbx.Expression, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fexpr.ErrEmpty
|
||||
}
|
||||
|
||||
result := &concatExpr{separator: " "}
|
||||
|
||||
for _, group := range data {
|
||||
var expr dbx.Expression
|
||||
var exprErr error
|
||||
|
||||
switch item := group.Item.(type) {
|
||||
case fexpr.Expr:
|
||||
if *maxExpressions <= 0 {
|
||||
return nil, ErrFilterExprLimit
|
||||
}
|
||||
|
||||
*maxExpressions--
|
||||
|
||||
expr, exprErr = resolveTokenizedExpr(item, fieldResolver)
|
||||
case fexpr.ExprGroup:
|
||||
expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver, maxExpressions)
|
||||
case []fexpr.ExprGroup:
|
||||
expr, exprErr = buildParsedFilterExpr(item, fieldResolver, maxExpressions)
|
||||
default:
|
||||
exprErr = errors.New("unsupported expression item")
|
||||
}
|
||||
|
||||
if exprErr != nil {
|
||||
return nil, exprErr
|
||||
}
|
||||
|
||||
if len(result.parts) > 0 {
|
||||
var op string
|
||||
if group.Join == fexpr.JoinOr {
|
||||
op = "OR"
|
||||
} else {
|
||||
op = "AND"
|
||||
}
|
||||
result.parts = append(result.parts, &opExpr{op})
|
||||
}
|
||||
|
||||
result.parts = append(result.parts, expr)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
lResult, lErr := resolveToken(expr.Left, fieldResolver)
|
||||
if lErr != nil || lResult.Identifier == "" {
|
||||
return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, lErr)
|
||||
}
|
||||
|
||||
rResult, rErr := resolveToken(expr.Right, fieldResolver)
|
||||
if rErr != nil || rResult.Identifier == "" {
|
||||
return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rErr)
|
||||
}
|
||||
|
||||
return buildResolversExpr(lResult, expr.Op, rResult)
|
||||
}
|
||||
|
||||
func buildResolversExpr(
|
||||
left *ResolverResult,
|
||||
op fexpr.SignOp,
|
||||
right *ResolverResult,
|
||||
) (dbx.Expression, error) {
|
||||
var expr dbx.Expression
|
||||
|
||||
switch op {
|
||||
case fexpr.SignEq, fexpr.SignAnyEq:
|
||||
expr = resolveEqualExpr(true, left, right)
|
||||
case fexpr.SignNeq, fexpr.SignAnyNeq:
|
||||
expr = resolveEqualExpr(false, left, right)
|
||||
case fexpr.SignLike, fexpr.SignAnyLike:
|
||||
// the right side is a column and therefor wrap it with "%" for contains like behavior
|
||||
if len(right.Params) == 0 {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
|
||||
} else {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
|
||||
}
|
||||
case fexpr.SignNlike, fexpr.SignAnyNlike:
|
||||
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
|
||||
if len(right.Params) == 0 {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
|
||||
} else {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
|
||||
}
|
||||
case fexpr.SignLt, fexpr.SignAnyLt:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s < %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignLte, fexpr.SignAnyLte:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s <= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignGt, fexpr.SignAnyGt:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s > %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignGte, fexpr.SignAnyGte:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s >= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
}
|
||||
|
||||
if expr == nil {
|
||||
return nil, fmt.Errorf("unknown expression operator %q", op)
|
||||
}
|
||||
|
||||
// multi-match expressions
|
||||
if !isAnyMatchOp(op) {
|
||||
if left.MultiMatchSubQuery != nil && right.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsManyExpr{
|
||||
left: left,
|
||||
right: right,
|
||||
op: op,
|
||||
}
|
||||
|
||||
expr = dbx.Enclose(dbx.And(expr, mm))
|
||||
} else if left.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsOneExpr{
|
||||
noCoalesce: left.NoCoalesce,
|
||||
subQuery: left.MultiMatchSubQuery,
|
||||
op: op,
|
||||
otherOperand: right,
|
||||
}
|
||||
|
||||
expr = dbx.Enclose(dbx.And(expr, mm))
|
||||
} else if right.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsOneExpr{
|
||||
noCoalesce: right.NoCoalesce,
|
||||
subQuery: right.MultiMatchSubQuery,
|
||||
op: op,
|
||||
otherOperand: left,
|
||||
inverse: true,
|
||||
}
|
||||
|
||||
expr = dbx.Enclose(dbx.And(expr, mm))
|
||||
}
|
||||
}
|
||||
|
||||
if left.AfterBuild != nil {
|
||||
expr = left.AfterBuild(expr)
|
||||
}
|
||||
|
||||
if right.AfterBuild != nil {
|
||||
expr = right.AfterBuild(expr)
|
||||
}
|
||||
|
||||
return expr, nil
|
||||
}
|
||||
|
||||
var normalizedIdentifiers = map[string]string{
|
||||
// if `null` field is missing, treat `null` identifier as NULL token
|
||||
"null": "NULL",
|
||||
// if `true` field is missing, treat `true` identifier as TRUE token
|
||||
"true": "1",
|
||||
// if `false` field is missing, treat `false` identifier as FALSE token
|
||||
"false": "0",
|
||||
}
|
||||
|
||||
func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) {
|
||||
switch token.Type {
|
||||
case fexpr.TokenIdentifier:
|
||||
// check for macros
|
||||
// ---
|
||||
if macroFunc, ok := identifierMacros[token.Literal]; ok {
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
|
||||
macroValue, err := macroFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: macroValue},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// custom resolver
|
||||
// ---
|
||||
result, err := fieldResolver.Resolve(token.Literal)
|
||||
if err != nil || result.Identifier == "" {
|
||||
for k, v := range normalizedIdentifiers {
|
||||
if strings.EqualFold(k, token.Literal) {
|
||||
return &ResolverResult{Identifier: v}, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, err
|
||||
case fexpr.TokenText:
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: token.Literal},
|
||||
}, nil
|
||||
case fexpr.TokenNumber:
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: cast.ToFloat64(token.Literal)},
|
||||
}, nil
|
||||
case fexpr.TokenFunction:
|
||||
fn, ok := TokenFunctions[token.Literal]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown function %q", token.Literal)
|
||||
}
|
||||
|
||||
args, _ := token.Meta.([]fexpr.Token)
|
||||
return fn(func(argToken fexpr.Token) (*ResolverResult, error) {
|
||||
return resolveToken(argToken, fieldResolver)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported token type %q", token.Type)
|
||||
}
|
||||
|
||||
// Resolves = and != expressions in an attempt to minimize the COALESCE
|
||||
// usage and to gracefully handle null vs empty string normalizations.
|
||||
//
|
||||
// The expression `a = "" OR a is null` tends to perform better than
|
||||
// `COALESCE(a, "") = ""` since the direct match can be accomplished
|
||||
// with a seek while the COALESCE will induce a table scan.
|
||||
func resolveEqualExpr(equal bool, left, right *ResolverResult) dbx.Expression {
|
||||
isLeftEmpty := isEmptyIdentifier(left) || (len(left.Params) == 1 && hasEmptyParamValue(left))
|
||||
isRightEmpty := isEmptyIdentifier(right) || (len(right.Params) == 1 && hasEmptyParamValue(right))
|
||||
|
||||
equalOp := "="
|
||||
nullEqualOp := "IS"
|
||||
concatOp := "OR"
|
||||
nullExpr := "IS NULL"
|
||||
if !equal {
|
||||
// always use `IS NOT` instead of `!=` because direct non-equal comparisons
|
||||
// to nullable column values that are actually NULL yields to NULL instead of TRUE, eg.:
|
||||
// `'example' != nullableColumn` -> NULL even if nullableColumn row value is NULL
|
||||
equalOp = "IS NOT"
|
||||
nullEqualOp = equalOp
|
||||
concatOp = "AND"
|
||||
nullExpr = "IS NOT NULL"
|
||||
}
|
||||
|
||||
// no coalesce (eg. compare to a json field)
|
||||
// a IS b
|
||||
// a IS NOT b
|
||||
if left.NoCoalesce || right.NoCoalesce {
|
||||
return dbx.NewExp(
|
||||
fmt.Sprintf("%s %s %s", left.Identifier, nullEqualOp, right.Identifier),
|
||||
mergeParams(left.Params, right.Params),
|
||||
)
|
||||
}
|
||||
|
||||
// both operands are empty
|
||||
if isLeftEmpty && isRightEmpty {
|
||||
return dbx.NewExp(fmt.Sprintf("'' %s ''", equalOp), mergeParams(left.Params, right.Params))
|
||||
}
|
||||
|
||||
// direct compare since at least one of the operands is known to be non-empty
|
||||
// eg. a = 'example'
|
||||
if isKnownNonEmptyIdentifier(left) || isKnownNonEmptyIdentifier(right) {
|
||||
leftIdentifier := left.Identifier
|
||||
if isLeftEmpty {
|
||||
leftIdentifier = "''"
|
||||
}
|
||||
rightIdentifier := right.Identifier
|
||||
if isRightEmpty {
|
||||
rightIdentifier = "''"
|
||||
}
|
||||
return dbx.NewExp(
|
||||
fmt.Sprintf("%s %s %s", leftIdentifier, equalOp, rightIdentifier),
|
||||
mergeParams(left.Params, right.Params),
|
||||
)
|
||||
}
|
||||
|
||||
// "" = b OR b IS NULL
|
||||
// "" IS NOT b AND b IS NOT NULL
|
||||
if isLeftEmpty {
|
||||
return dbx.NewExp(
|
||||
fmt.Sprintf("('' %s %s %s %s %s)", equalOp, right.Identifier, concatOp, right.Identifier, nullExpr),
|
||||
mergeParams(left.Params, right.Params),
|
||||
)
|
||||
}
|
||||
|
||||
// a = "" OR a IS NULL
|
||||
// a IS NOT "" AND a IS NOT NULL
|
||||
if isRightEmpty {
|
||||
return dbx.NewExp(
|
||||
fmt.Sprintf("(%s %s '' %s %s %s)", left.Identifier, equalOp, concatOp, left.Identifier, nullExpr),
|
||||
mergeParams(left.Params, right.Params),
|
||||
)
|
||||
}
|
||||
|
||||
// fallback to a COALESCE comparison
|
||||
return dbx.NewExp(
|
||||
fmt.Sprintf(
|
||||
"COALESCE(%s, '') %s COALESCE(%s, '')",
|
||||
left.Identifier,
|
||||
equalOp,
|
||||
right.Identifier,
|
||||
),
|
||||
mergeParams(left.Params, right.Params),
|
||||
)
|
||||
}
|
||||
|
||||
func hasEmptyParamValue(result *ResolverResult) bool {
|
||||
for _, p := range result.Params {
|
||||
switch v := p.(type) {
|
||||
case nil:
|
||||
return true
|
||||
case string:
|
||||
if v == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isKnownNonEmptyIdentifier(result *ResolverResult) bool {
|
||||
switch strings.ToLower(result.Identifier) {
|
||||
case "1", "0", "false", `true`:
|
||||
return true
|
||||
}
|
||||
|
||||
return len(result.Params) > 0 && !hasEmptyParamValue(result) && !isEmptyIdentifier(result)
|
||||
}
|
||||
|
||||
func isEmptyIdentifier(result *ResolverResult) bool {
|
||||
switch strings.ToLower(result.Identifier) {
|
||||
case "", "null", "''", `""`, "``":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isAnyMatchOp(op fexpr.SignOp) bool {
|
||||
switch op {
|
||||
case
|
||||
fexpr.SignAnyEq,
|
||||
fexpr.SignAnyNeq,
|
||||
fexpr.SignAnyLike,
|
||||
fexpr.SignAnyNlike,
|
||||
fexpr.SignAnyLt,
|
||||
fexpr.SignAnyLte,
|
||||
fexpr.SignAnyGt,
|
||||
fexpr.SignAnyGte:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// mergeParams returns new dbx.Params where each provided params item
|
||||
// is merged in the order they are specified.
|
||||
func mergeParams(params ...dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
for _, p := range params {
|
||||
for k, v := range p {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// @todo consider adding support for custom single character wildcard
|
||||
//
|
||||
// wrapLikeParams wraps each provided param value string with `%`
|
||||
// if the param doesn't contain an explicit wildcard (`%`) character already.
|
||||
func wrapLikeParams(params dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
for k, v := range params {
|
||||
vStr := cast.ToString(v)
|
||||
if !containsUnescapedChar(vStr, '%') {
|
||||
// note: this is done to minimize the breaking changes and to preserve the original autoescape behavior
|
||||
vStr = escapeUnescapedChars(vStr, '\\', '%', '_')
|
||||
vStr = "%" + vStr + "%"
|
||||
}
|
||||
result[k] = vStr
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func escapeUnescapedChars(str string, escapeChars ...rune) string {
|
||||
rs := []rune(str)
|
||||
total := len(rs)
|
||||
result := make([]rune, 0, total)
|
||||
|
||||
var match bool
|
||||
|
||||
for i := total - 1; i >= 0; i-- {
|
||||
if match {
|
||||
// check if already escaped
|
||||
if rs[i] != '\\' {
|
||||
result = append(result, '\\')
|
||||
}
|
||||
match = false
|
||||
} else {
|
||||
for _, ec := range escapeChars {
|
||||
if rs[i] == ec {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, rs[i])
|
||||
|
||||
// in case the matching char is at the beginning
|
||||
if i == 0 && match {
|
||||
result = append(result, '\\')
|
||||
}
|
||||
}
|
||||
|
||||
// reverse
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func containsUnescapedChar(str string, ch rune) bool {
|
||||
var prev rune
|
||||
|
||||
for _, c := range str {
|
||||
if c == ch && prev != '\\' {
|
||||
return true
|
||||
}
|
||||
|
||||
if c == '\\' && prev == '\\' {
|
||||
prev = rune(0) // reset escape sequence
|
||||
} else {
|
||||
prev = c
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*opExpr)(nil)
|
||||
|
||||
// opExpr defines an expression that contains a raw sql operator string.
|
||||
type opExpr struct {
|
||||
op string
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *opExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
return e.op
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*concatExpr)(nil)
|
||||
|
||||
// concatExpr defines an expression that concatenates multiple
|
||||
// other expressions with a specified separator.
|
||||
type concatExpr struct {
|
||||
separator string
|
||||
parts []dbx.Expression
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if len(e.parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
stringParts := make([]string, 0, len(e.parts))
|
||||
|
||||
for _, p := range e.parts {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if sql := p.Build(db, params); sql != "" {
|
||||
stringParts = append(stringParts, sql)
|
||||
}
|
||||
}
|
||||
|
||||
// skip extra parenthesis for single concat expression
|
||||
if len(stringParts) == 1 &&
|
||||
// check for already concatenated raw/plain expressions
|
||||
!strings.Contains(strings.ToUpper(stringParts[0]), " AND ") &&
|
||||
!strings.Contains(strings.ToUpper(stringParts[0]), " OR ") {
|
||||
return stringParts[0]
|
||||
}
|
||||
|
||||
return "(" + strings.Join(stringParts, e.separator) + ")"
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*manyVsManyExpr)(nil)
|
||||
|
||||
// manyVsManyExpr constructs a multi-match many<->many db where expression.
|
||||
//
|
||||
// Expects leftSubQuery and rightSubQuery to return a subquery with a
|
||||
// single "multiMatchValue" column.
|
||||
type manyVsManyExpr struct {
|
||||
left *ResolverResult
|
||||
right *ResolverResult
|
||||
op fexpr.SignOp
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if e.left.MultiMatchSubQuery == nil || e.right.MultiMatchSubQuery == nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
lAlias := "__ml" + security.PseudorandomString(8)
|
||||
rAlias := "__mr" + security.PseudorandomString(8)
|
||||
|
||||
whereExpr, buildErr := buildResolversExpr(
|
||||
&ResolverResult{
|
||||
NoCoalesce: e.left.NoCoalesce,
|
||||
Identifier: "[[" + lAlias + ".multiMatchValue]]",
|
||||
},
|
||||
e.op,
|
||||
&ResolverResult{
|
||||
NoCoalesce: e.right.NoCoalesce,
|
||||
Identifier: "[[" + rAlias + ".multiMatchValue]]",
|
||||
// note: the AfterBuild needs to be handled only once and it
|
||||
// doesn't matter whether it is applied on the left or right subquery operand
|
||||
AfterBuild: dbx.Not, // inverse for the not-exist expression
|
||||
},
|
||||
)
|
||||
|
||||
if buildErr != nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} LEFT JOIN (%s) {{%s}} WHERE %s)",
|
||||
e.left.MultiMatchSubQuery.Build(db, params),
|
||||
lAlias,
|
||||
e.right.MultiMatchSubQuery.Build(db, params),
|
||||
rAlias,
|
||||
whereExpr.Build(db, params),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*manyVsOneExpr)(nil)
|
||||
|
||||
// manyVsOneExpr constructs a multi-match many<->one db where expression.
|
||||
//
|
||||
// Expects subQuery to return a subquery with a single "multiMatchValue" column.
|
||||
//
|
||||
// You can set inverse=false to reverse the condition sides (aka. one<->many).
|
||||
type manyVsOneExpr struct {
|
||||
otherOperand *ResolverResult
|
||||
subQuery dbx.Expression
|
||||
op fexpr.SignOp
|
||||
inverse bool
|
||||
noCoalesce bool
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if e.subQuery == nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
alias := "__sm" + security.PseudorandomString(8)
|
||||
|
||||
r1 := &ResolverResult{
|
||||
NoCoalesce: e.noCoalesce,
|
||||
Identifier: "[[" + alias + ".multiMatchValue]]",
|
||||
AfterBuild: dbx.Not, // inverse for the not-exist expression
|
||||
}
|
||||
|
||||
r2 := &ResolverResult{
|
||||
Identifier: e.otherOperand.Identifier,
|
||||
Params: e.otherOperand.Params,
|
||||
}
|
||||
|
||||
var whereExpr dbx.Expression
|
||||
var buildErr error
|
||||
|
||||
if e.inverse {
|
||||
whereExpr, buildErr = buildResolversExpr(r2, e.op, r1)
|
||||
} else {
|
||||
whereExpr, buildErr = buildResolversExpr(r1, e.op, r2)
|
||||
}
|
||||
|
||||
if buildErr != nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} WHERE %s)",
|
||||
e.subQuery.Build(db, params),
|
||||
alias,
|
||||
whereExpr.Build(db, params),
|
||||
)
|
||||
}
|
341
tools/search/filter_test.go
Normal file
341
tools/search/filter_test.go
Normal file
|
@ -0,0 +1,341 @@
|
|||
package search_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestFilterDataBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", `^test4_\w+$`, `^test5\.[\w\.\:]*\w+$`)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
filterData search.FilterData
|
||||
expectError bool
|
||||
expectPattern string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid format",
|
||||
"(test1 > 1",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid operator",
|
||||
"test1 + 123",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"unknown field",
|
||||
"test1 = 'example' && unknown > 1",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"simple expression",
|
||||
"test1 > 1",
|
||||
false,
|
||||
"[[test1]] > {:TEST}",
|
||||
},
|
||||
{
|
||||
"empty string vs null",
|
||||
"'' = null && null != ''",
|
||||
false,
|
||||
"('' = '' AND '' IS NOT '')",
|
||||
},
|
||||
{
|
||||
"like with 2 columns",
|
||||
"test1 ~ test2",
|
||||
false,
|
||||
"[[test1]] LIKE ('%' || [[test2]] || '%') ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"like with right column operand",
|
||||
"'lorem' ~ test1",
|
||||
false,
|
||||
"{:TEST} LIKE ('%' || [[test1]] || '%') ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"like with left column operand and text as right operand",
|
||||
"test1 ~ 'lorem'",
|
||||
false,
|
||||
"[[test1]] LIKE {:TEST} ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"not like with 2 columns",
|
||||
"test1 !~ test2",
|
||||
false,
|
||||
"[[test1]] NOT LIKE ('%' || [[test2]] || '%') ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"not like with right column operand",
|
||||
"'lorem' !~ test1",
|
||||
false,
|
||||
"{:TEST} NOT LIKE ('%' || [[test1]] || '%') ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"like with left column operand and text as right operand",
|
||||
"test1 !~ 'lorem'",
|
||||
false,
|
||||
"[[test1]] NOT LIKE {:TEST} ESCAPE '\\'",
|
||||
},
|
||||
{
|
||||
"nested json no coalesce",
|
||||
"test5.a = test5.b || test5.c != test5.d",
|
||||
false,
|
||||
"(JSON_EXTRACT([[test5]], '$.a') IS JSON_EXTRACT([[test5]], '$.b') OR JSON_EXTRACT([[test5]], '$.c') IS NOT JSON_EXTRACT([[test5]], '$.d'))",
|
||||
},
|
||||
{
|
||||
"macros",
|
||||
`
|
||||
test4_1 > @now &&
|
||||
test4_2 > @second &&
|
||||
test4_3 > @minute &&
|
||||
test4_4 > @hour &&
|
||||
test4_5 > @day &&
|
||||
test4_6 > @year &&
|
||||
test4_7 > @month &&
|
||||
test4_9 > @weekday &&
|
||||
test4_9 > @todayStart &&
|
||||
test4_10 > @todayEnd &&
|
||||
test4_11 > @monthStart &&
|
||||
test4_12 > @monthEnd &&
|
||||
test4_13 > @yearStart &&
|
||||
test4_14 > @yearEnd
|
||||
`,
|
||||
false,
|
||||
"([[test4_1]] > {:TEST} AND [[test4_2]] > {:TEST} AND [[test4_3]] > {:TEST} AND [[test4_4]] > {:TEST} AND [[test4_5]] > {:TEST} AND [[test4_6]] > {:TEST} AND [[test4_7]] > {:TEST} AND [[test4_9]] > {:TEST} AND [[test4_9]] > {:TEST} AND [[test4_10]] > {:TEST} AND [[test4_11]] > {:TEST} AND [[test4_12]] > {:TEST} AND [[test4_13]] > {:TEST} AND [[test4_14]] > {:TEST})",
|
||||
},
|
||||
{
|
||||
"complex expression",
|
||||
"((test1 > 1) || (test2 != 2)) && test3 ~ '%%example' && test4_sub = null",
|
||||
false,
|
||||
"(([[test1]] > {:TEST} OR [[test2]] IS NOT {:TEST}) AND [[test3]] LIKE {:TEST} ESCAPE '\\' AND ([[test4_sub]] = '' OR [[test4_sub]] IS NULL))",
|
||||
},
|
||||
{
|
||||
"combination of special literals (null, true, false)",
|
||||
"test1=true && test2 != false && null = test3 || null != test4_sub",
|
||||
false,
|
||||
"([[test1]] = 1 AND [[test2]] IS NOT 0 AND ('' = [[test3]] OR [[test3]] IS NULL) OR ('' IS NOT [[test4_sub]] AND [[test4_sub]] IS NOT NULL))",
|
||||
},
|
||||
{
|
||||
"all operators",
|
||||
"(test1 = test2 || test2 != test3) && (test2 ~ 'example' || test2 !~ '%%abc') && 'switch1%%' ~ test1 && 'switch2' !~ test2 && test3 > 1 && test3 >= 0 && test3 <= 4 && 2 < 5",
|
||||
false,
|
||||
"((COALESCE([[test1]], '') = COALESCE([[test2]], '') OR COALESCE([[test2]], '') IS NOT COALESCE([[test3]], '')) AND ([[test2]] LIKE {:TEST} ESCAPE '\\' OR [[test2]] NOT LIKE {:TEST} ESCAPE '\\') AND {:TEST} LIKE ('%' || [[test1]] || '%') ESCAPE '\\' AND {:TEST} NOT LIKE ('%' || [[test2]] || '%') ESCAPE '\\' AND [[test3]] > {:TEST} AND [[test3]] >= {:TEST} AND [[test3]] <= {:TEST} AND {:TEST} < {:TEST})",
|
||||
},
|
||||
{
|
||||
"geoDistance function",
|
||||
"geoDistance(1,2,3,4) < 567",
|
||||
false,
|
||||
"(6371 * acos(cos(radians({:TEST})) * cos(radians({:TEST})) * cos(radians({:TEST}) - radians({:TEST})) + sin(radians({:TEST})) * sin(radians({:TEST})))) < {:TEST}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
expr, err := s.filterData.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
dummyDB := &dbx.DB{}
|
||||
|
||||
rawSql := expr.Build(dummyDB, dbx.Params{})
|
||||
|
||||
// replace TEST placeholder with .+ regex pattern
|
||||
expectPattern := strings.ReplaceAll(
|
||||
"^"+regexp.QuoteMeta(s.expectPattern)+"$",
|
||||
"TEST",
|
||||
`\w+`,
|
||||
)
|
||||
|
||||
pattern := regexp.MustCompile(expectPattern)
|
||||
if !pattern.MatchString(rawSql) {
|
||||
t.Fatalf("[%s] Pattern %v don't match with expression: \n%v", s.name, expectPattern, rawSql)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterDataBuildExprWithParams(t *testing.T) {
|
||||
// create a dummy db
|
||||
sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := dbx.NewFromDB(sqlDB, "sqlite")
|
||||
|
||||
calledQueries := []string{}
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
|
||||
date, err := time.Parse("2006-01-02", "2023-01-01")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resolver := search.NewSimpleFieldResolver(`^test\w+$`)
|
||||
|
||||
filter := search.FilterData(`
|
||||
test1 = {:test1} ||
|
||||
test2 = {:test2} ||
|
||||
test3a = {:test3} ||
|
||||
test3b = {:test3} ||
|
||||
test4 = {:test4} ||
|
||||
test5 = {:test5} ||
|
||||
test6 = {:test6} ||
|
||||
test7 = {:test7} ||
|
||||
test8 = {:test8} ||
|
||||
test9 = {:test9} ||
|
||||
test10 = {:test10} ||
|
||||
test11 = {:test11} ||
|
||||
test12 = {:test12}
|
||||
`)
|
||||
|
||||
replacements := []dbx.Params{
|
||||
{"test1": true},
|
||||
{"test2": false},
|
||||
{"test3": 123.456},
|
||||
{"test4": nil},
|
||||
{"test5": "", "test6": "simple", "test7": `'single_quotes'`, "test8": `"double_quotes"`, "test9": `escape\"quote`},
|
||||
{"test10": date},
|
||||
{"test11": []string{"a", "b", `"quote`}},
|
||||
{"test12": map[string]any{"a": 123, "b": `quote"`}},
|
||||
}
|
||||
|
||||
expr, err := filter.BuildExpr(resolver, replacements...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Select().Where(expr).Build().Execute()
|
||||
|
||||
if len(calledQueries) != 1 {
|
||||
t.Fatalf("Expected 1 query, got %d", len(calledQueries))
|
||||
}
|
||||
|
||||
expectedQuery := `SELECT * WHERE ([[test1]] = 1 OR [[test2]] = 0 OR [[test3a]] = 123.456 OR [[test3b]] = 123.456 OR ([[test4]] = '' OR [[test4]] IS NULL) OR [[test5]] = '""' OR [[test6]] = 'simple' OR [[test7]] = '''single_quotes''' OR [[test8]] = '"double_quotes"' OR [[test9]] = 'escape\\"quote' OR [[test10]] = '2023-01-01 00:00:00 +0000 UTC' OR [[test11]] = '["a","b","\\"quote"]' OR [[test12]] = '{"a":123,"b":"quote\\""}')`
|
||||
if expectedQuery != calledQueries[0] {
|
||||
t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterDataBuildExprWithLimit(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver(`^\w+$`)
|
||||
|
||||
scenarios := []struct {
|
||||
limit int
|
||||
filter search.FilterData
|
||||
expectError bool
|
||||
}{
|
||||
{1, "1 = 1", false},
|
||||
{0, "1 = 1", true}, // new cache entry should be created
|
||||
{2, "1 = 1 || 1 = 1", false},
|
||||
{1, "1 = 1 || 1 = 1", true},
|
||||
{3, "1 = 1 || 1 = 1", false},
|
||||
{6, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", false},
|
||||
{5, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", true},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("limit_%d:%d", i, s.limit), func(t *testing.T) {
|
||||
_, err := s.filter.BuildExprWithLimit(resolver, s.limit)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLikeParamsWrapping(t *testing.T) {
|
||||
// create a dummy db
|
||||
sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := dbx.NewFromDB(sqlDB, "sqlite")
|
||||
|
||||
calledQueries := []string{}
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
|
||||
resolver := search.NewSimpleFieldResolver(`^test\w+$`)
|
||||
|
||||
filter := search.FilterData(`
|
||||
test1 ~ {:p1} ||
|
||||
test2 ~ {:p2} ||
|
||||
test3 ~ {:p3} ||
|
||||
test4 ~ {:p4} ||
|
||||
test5 ~ {:p5} ||
|
||||
test6 ~ {:p6} ||
|
||||
test7 ~ {:p7} ||
|
||||
test8 ~ {:p8} ||
|
||||
test9 ~ {:p9} ||
|
||||
test10 ~ {:p10} ||
|
||||
test11 ~ {:p11} ||
|
||||
test12 ~ {:p12}
|
||||
`)
|
||||
|
||||
replacements := []dbx.Params{
|
||||
{"p1": `abc`},
|
||||
{"p2": `ab%c`},
|
||||
{"p3": `ab\%c`},
|
||||
{"p4": `%ab\%c`},
|
||||
{"p5": `ab\\%c`},
|
||||
{"p6": `ab\\\%c`},
|
||||
{"p7": `ab_c`},
|
||||
{"p8": `ab\_c`},
|
||||
{"p9": `%ab_c`},
|
||||
{"p10": `ab\c`},
|
||||
{"p11": `_ab\c_`},
|
||||
{"p12": `ab\c%`},
|
||||
}
|
||||
|
||||
expr, err := filter.BuildExpr(resolver, replacements...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Select().Where(expr).Build().Execute()
|
||||
|
||||
if len(calledQueries) != 1 {
|
||||
t.Fatalf("Expected 1 query, got %d", len(calledQueries))
|
||||
}
|
||||
|
||||
expectedQuery := `SELECT * WHERE ([[test1]] LIKE '%abc%' ESCAPE '\' OR [[test2]] LIKE 'ab%c' ESCAPE '\' OR [[test3]] LIKE 'ab\\%c' ESCAPE '\' OR [[test4]] LIKE '%ab\\%c' ESCAPE '\' OR [[test5]] LIKE 'ab\\\\%c' ESCAPE '\' OR [[test6]] LIKE 'ab\\\\\\%c' ESCAPE '\' OR [[test7]] LIKE '%ab\_c%' ESCAPE '\' OR [[test8]] LIKE '%ab\\\_c%' ESCAPE '\' OR [[test9]] LIKE '%ab_c' ESCAPE '\' OR [[test10]] LIKE '%ab\\c%' ESCAPE '\' OR [[test11]] LIKE '%\_ab\\c\_%' ESCAPE '\' OR [[test12]] LIKE 'ab\\c%' ESCAPE '\')`
|
||||
if expectedQuery != calledQueries[0] {
|
||||
t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
|
||||
}
|
||||
}
|
135
tools/search/identifier_macros.go
Normal file
135
tools/search/identifier_macros.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// note: used primarily for the tests
|
||||
var timeNow = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
var identifierMacros = map[string]func() (any, error){
|
||||
"@now": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
|
||||
d, err := types.ParseDateTime(today)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@now: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@yesterday": func() (any, error) {
|
||||
yesterday := timeNow().UTC().AddDate(0, 0, -1)
|
||||
|
||||
d, err := types.ParseDateTime(yesterday)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@yesterday: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@tomorrow": func() (any, error) {
|
||||
tomorrow := timeNow().UTC().AddDate(0, 0, 1)
|
||||
|
||||
d, err := types.ParseDateTime(tomorrow)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@tomorrow: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@second": func() (any, error) {
|
||||
return timeNow().UTC().Second(), nil
|
||||
},
|
||||
"@minute": func() (any, error) {
|
||||
return timeNow().UTC().Minute(), nil
|
||||
},
|
||||
"@hour": func() (any, error) {
|
||||
return timeNow().UTC().Hour(), nil
|
||||
},
|
||||
"@day": func() (any, error) {
|
||||
return timeNow().UTC().Day(), nil
|
||||
},
|
||||
"@month": func() (any, error) {
|
||||
return int(timeNow().UTC().Month()), nil
|
||||
},
|
||||
"@weekday": func() (any, error) {
|
||||
return int(timeNow().UTC().Weekday()), nil
|
||||
},
|
||||
"@year": func() (any, error) {
|
||||
return timeNow().UTC().Year(), nil
|
||||
},
|
||||
"@todayStart": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
start := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC)
|
||||
|
||||
d, err := types.ParseDateTime(start)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@todayStart: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@todayEnd": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
|
||||
start := time.Date(today.Year(), today.Month(), today.Day(), 23, 59, 59, 999999999, time.UTC)
|
||||
|
||||
d, err := types.ParseDateTime(start)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@todayEnd: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@monthStart": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
start := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
d, err := types.ParseDateTime(start)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@monthStart: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@monthEnd": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
start := time.Date(today.Year(), today.Month(), 1, 23, 59, 59, 999999999, time.UTC)
|
||||
end := start.AddDate(0, 1, -1)
|
||||
|
||||
d, err := types.ParseDateTime(end)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@monthEnd: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@yearStart": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
start := time.Date(today.Year(), 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
d, err := types.ParseDateTime(start)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@yearStart: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
"@yearEnd": func() (any, error) {
|
||||
today := timeNow().UTC()
|
||||
end := time.Date(today.Year(), 12, 31, 23, 59, 59, 999999999, time.UTC)
|
||||
|
||||
d, err := types.ParseDateTime(end)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("@yearEnd: %w", err)
|
||||
}
|
||||
|
||||
return d.String(), nil
|
||||
},
|
||||
}
|
58
tools/search/identifier_macros_test.go
Normal file
58
tools/search/identifier_macros_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIdentifierMacros(t *testing.T) {
|
||||
originalTimeNow := timeNow
|
||||
|
||||
timeNow = func() time.Time {
|
||||
return time.Date(2023, 2, 3, 4, 5, 6, 7, time.UTC)
|
||||
}
|
||||
|
||||
testMacros := map[string]any{
|
||||
"@now": "2023-02-03 04:05:06.000Z",
|
||||
"@yesterday": "2023-02-02 04:05:06.000Z",
|
||||
"@tomorrow": "2023-02-04 04:05:06.000Z",
|
||||
"@second": 6,
|
||||
"@minute": 5,
|
||||
"@hour": 4,
|
||||
"@day": 3,
|
||||
"@month": 2,
|
||||
"@weekday": 5,
|
||||
"@year": 2023,
|
||||
"@todayStart": "2023-02-03 00:00:00.000Z",
|
||||
"@todayEnd": "2023-02-03 23:59:59.999Z",
|
||||
"@monthStart": "2023-02-01 00:00:00.000Z",
|
||||
"@monthEnd": "2023-02-28 23:59:59.999Z",
|
||||
"@yearStart": "2023-01-01 00:00:00.000Z",
|
||||
"@yearEnd": "2023-12-31 23:59:59.999Z",
|
||||
}
|
||||
|
||||
if len(testMacros) != len(identifierMacros) {
|
||||
t.Fatalf("Expected %d macros, got %d", len(testMacros), len(identifierMacros))
|
||||
}
|
||||
|
||||
for key, expected := range testMacros {
|
||||
t.Run(key, func(t *testing.T) {
|
||||
macro, ok := identifierMacros[key]
|
||||
if !ok {
|
||||
t.Fatalf("Missing macro %s", key)
|
||||
}
|
||||
|
||||
result, err := macro()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %q, got %q", expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// restore
|
||||
timeNow = originalTimeNow
|
||||
}
|
361
tools/search/provider.go
Normal file
361
tools/search/provider.go
Normal file
|
@ -0,0 +1,361 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultPerPage specifies the default number of returned search result items.
|
||||
DefaultPerPage int = 30
|
||||
|
||||
// DefaultFilterExprLimit specifies the default filter expressions limit.
|
||||
DefaultFilterExprLimit int = 200
|
||||
|
||||
// DefaultSortExprLimit specifies the default sort expressions limit.
|
||||
DefaultSortExprLimit int = 8
|
||||
|
||||
// MaxPerPage specifies the max allowed search result items returned in a single page.
|
||||
MaxPerPage int = 1000
|
||||
|
||||
// MaxFilterLength specifies the max allowed individual search filter parsable length.
|
||||
MaxFilterLength int = 3500
|
||||
|
||||
// MaxSortFieldLength specifies the max allowed individual sort field parsable length.
|
||||
MaxSortFieldLength int = 255
|
||||
)
|
||||
|
||||
// Common search errors.
|
||||
var (
|
||||
ErrEmptyQuery = errors.New("search query is not set")
|
||||
ErrSortExprLimit = errors.New("max sort expressions limit reached")
|
||||
ErrFilterExprLimit = errors.New("max filter expressions limit reached")
|
||||
ErrFilterLengthLimit = errors.New("max filter length limit reached")
|
||||
ErrSortFieldLengthLimit = errors.New("max sort field length limit reached")
|
||||
)
|
||||
|
||||
// URL search query params
|
||||
const (
|
||||
PageQueryParam string = "page"
|
||||
PerPageQueryParam string = "perPage"
|
||||
SortQueryParam string = "sort"
|
||||
FilterQueryParam string = "filter"
|
||||
SkipTotalQueryParam string = "skipTotal"
|
||||
)
|
||||
|
||||
// Result defines the returned search result structure.
|
||||
type Result struct {
|
||||
Items any `json:"items"`
|
||||
Page int `json:"page"`
|
||||
PerPage int `json:"perPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
// Provider represents a single configured search provider instance.
|
||||
type Provider struct {
|
||||
fieldResolver FieldResolver
|
||||
query *dbx.SelectQuery
|
||||
countCol string
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
page int
|
||||
perPage int
|
||||
skipTotal bool
|
||||
maxFilterExprLimit int
|
||||
maxSortExprLimit int
|
||||
}
|
||||
|
||||
// NewProvider initializes and returns a new search provider.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// baseQuery := db.Select("*").From("user")
|
||||
// fieldResolver := search.NewSimpleFieldResolver("id", "name")
|
||||
// models := []*YourDataStruct{}
|
||||
//
|
||||
// result, err := search.NewProvider(fieldResolver).
|
||||
// Query(baseQuery).
|
||||
// ParseAndExec("page=2&filter=id>0&sort=-email", &models)
|
||||
func NewProvider(fieldResolver FieldResolver) *Provider {
|
||||
return &Provider{
|
||||
fieldResolver: fieldResolver,
|
||||
countCol: "id",
|
||||
page: 1,
|
||||
perPage: DefaultPerPage,
|
||||
sort: []SortField{},
|
||||
filter: []FilterData{},
|
||||
maxFilterExprLimit: DefaultFilterExprLimit,
|
||||
maxSortExprLimit: DefaultSortExprLimit,
|
||||
}
|
||||
}
|
||||
|
||||
// MaxFilterExprLimit changes the default max allowed filter expressions.
|
||||
//
|
||||
// Note that currently the limit is applied individually for each separate filter.
|
||||
func (s *Provider) MaxFilterExprLimit(max int) *Provider {
|
||||
s.maxFilterExprLimit = max
|
||||
return s
|
||||
}
|
||||
|
||||
// MaxSortExprLimit changes the default max allowed sort expressions.
|
||||
func (s *Provider) MaxSortExprLimit(max int) *Provider {
|
||||
s.maxSortExprLimit = max
|
||||
return s
|
||||
}
|
||||
|
||||
// Query sets the base query that will be used to fetch the search items.
|
||||
func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
|
||||
s.query = query
|
||||
return s
|
||||
}
|
||||
|
||||
// SkipTotal changes the `skipTotal` field of the current search provider.
|
||||
func (s *Provider) SkipTotal(skipTotal bool) *Provider {
|
||||
s.skipTotal = skipTotal
|
||||
return s
|
||||
}
|
||||
|
||||
// CountCol allows changing the default column (id) that is used
|
||||
// to generate the COUNT SQL query statement.
|
||||
//
|
||||
// This field is ignored if skipTotal is true.
|
||||
func (s *Provider) CountCol(name string) *Provider {
|
||||
s.countCol = name
|
||||
return s
|
||||
}
|
||||
|
||||
// Page sets the `page` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `page` value is done during `Exec()`.
|
||||
func (s *Provider) Page(page int) *Provider {
|
||||
s.page = page
|
||||
return s
|
||||
}
|
||||
|
||||
// PerPage sets the `perPage` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `perPage` value is done during `Exec()`.
|
||||
func (s *Provider) PerPage(perPage int) *Provider {
|
||||
s.perPage = perPage
|
||||
return s
|
||||
}
|
||||
|
||||
// Sort sets the `sort` field of the current search provider.
|
||||
func (s *Provider) Sort(sort []SortField) *Provider {
|
||||
s.sort = sort
|
||||
return s
|
||||
}
|
||||
|
||||
// AddSort appends the provided SortField to the existing provider's sort field.
|
||||
func (s *Provider) AddSort(field SortField) *Provider {
|
||||
s.sort = append(s.sort, field)
|
||||
return s
|
||||
}
|
||||
|
||||
// Filter sets the `filter` field of the current search provider.
|
||||
func (s *Provider) Filter(filter []FilterData) *Provider {
|
||||
s.filter = filter
|
||||
return s
|
||||
}
|
||||
|
||||
// AddFilter appends the provided FilterData to the existing provider's filter field.
|
||||
func (s *Provider) AddFilter(filter FilterData) *Provider {
|
||||
if filter != "" {
|
||||
s.filter = append(s.filter, filter)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Parse parses the search query parameter from the provided query string
|
||||
// and assigns the found fields to the current search provider.
|
||||
//
|
||||
// The data from the "sort" and "filter" query parameters are appended
|
||||
// to the existing provider's `sort` and `filter` fields
|
||||
// (aka. using `AddSort` and `AddFilter`).
|
||||
func (s *Provider) Parse(urlQuery string) error {
|
||||
params, err := url.ParseQuery(urlQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if raw := params.Get(SkipTotalQueryParam); raw != "" {
|
||||
v, err := strconv.ParseBool(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.SkipTotal(v)
|
||||
}
|
||||
|
||||
if raw := params.Get(PageQueryParam); raw != "" {
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Page(v)
|
||||
}
|
||||
|
||||
if raw := params.Get(PerPageQueryParam); raw != "" {
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.PerPage(v)
|
||||
}
|
||||
|
||||
if raw := params.Get(SortQueryParam); raw != "" {
|
||||
for _, sortField := range ParseSortFromString(raw) {
|
||||
s.AddSort(sortField)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := params.Get(FilterQueryParam); raw != "" {
|
||||
s.AddFilter(FilterData(raw))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes the search provider and fills/scans
|
||||
// the provided `items` slice with the found models.
|
||||
func (s *Provider) Exec(items any) (*Result, error) {
|
||||
if s.query == nil {
|
||||
return nil, ErrEmptyQuery
|
||||
}
|
||||
|
||||
// shallow clone the provider's query
|
||||
modelsQuery := *s.query
|
||||
|
||||
// build filters
|
||||
for _, f := range s.filter {
|
||||
if len(f) > MaxFilterLength {
|
||||
return nil, ErrFilterLengthLimit
|
||||
}
|
||||
expr, err := f.BuildExprWithLimit(s.fieldResolver, s.maxFilterExprLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != nil {
|
||||
modelsQuery.AndWhere(expr)
|
||||
}
|
||||
}
|
||||
|
||||
// apply sorting
|
||||
if len(s.sort) > s.maxSortExprLimit {
|
||||
return nil, ErrSortExprLimit
|
||||
}
|
||||
for _, sortField := range s.sort {
|
||||
if len(sortField.Name) > MaxSortFieldLength {
|
||||
return nil, ErrSortFieldLengthLimit
|
||||
}
|
||||
expr, err := sortField.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != "" {
|
||||
// ensure that _rowid_ expressions are always prefixed with the first FROM table
|
||||
if sortField.Name == rowidSortKey && !strings.Contains(expr, ".") {
|
||||
queryInfo := modelsQuery.Info()
|
||||
if len(queryInfo.From) > 0 {
|
||||
expr = "[[" + inflector.Columnify(queryInfo.From[0]) + "]]." + expr
|
||||
}
|
||||
}
|
||||
|
||||
modelsQuery.AndOrderBy(expr)
|
||||
}
|
||||
}
|
||||
|
||||
// apply field resolver query modifications (if any)
|
||||
if err := s.fieldResolver.UpdateQuery(&modelsQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// normalize page
|
||||
if s.page <= 0 {
|
||||
s.page = 1
|
||||
}
|
||||
|
||||
// normalize perPage
|
||||
if s.perPage <= 0 {
|
||||
s.perPage = DefaultPerPage
|
||||
} else if s.perPage > MaxPerPage {
|
||||
s.perPage = MaxPerPage
|
||||
}
|
||||
|
||||
// negative value to differentiate from the zero default
|
||||
totalCount := -1
|
||||
totalPages := -1
|
||||
|
||||
// prepare a count query from the base one
|
||||
countQuery := modelsQuery // shallow clone
|
||||
countExec := func() error {
|
||||
queryInfo := countQuery.Info()
|
||||
countCol := s.countCol
|
||||
if len(queryInfo.From) > 0 {
|
||||
countCol = queryInfo.From[0] + "." + countCol
|
||||
}
|
||||
|
||||
// note: countQuery is shallow cloned and slice/map in-place modifications should be avoided
|
||||
err := countQuery.Distinct(false).
|
||||
Select("COUNT(DISTINCT [[" + countCol + "]])").
|
||||
OrderBy( /* reset */ ).
|
||||
Row(&totalCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalPages = int(math.Ceil(float64(totalCount) / float64(s.perPage)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// apply pagination to the original query and fetch the models
|
||||
modelsExec := func() error {
|
||||
modelsQuery.Limit(int64(s.perPage))
|
||||
modelsQuery.Offset(int64(s.perPage * (s.page - 1)))
|
||||
|
||||
return modelsQuery.All(items)
|
||||
}
|
||||
|
||||
if !s.skipTotal {
|
||||
// execute the 2 queries concurrently
|
||||
errg := new(errgroup.Group)
|
||||
errg.SetLimit(2)
|
||||
errg.Go(countExec)
|
||||
errg.Go(modelsExec)
|
||||
if err := errg.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := modelsExec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result := &Result{
|
||||
Page: s.page,
|
||||
PerPage: s.perPage,
|
||||
TotalItems: totalCount,
|
||||
TotalPages: totalPages,
|
||||
Items: items,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ParseAndExec is a short convenient method to trigger both
|
||||
// `Parse()` and `Exec()` in a single call.
|
||||
func (s *Provider) ParseAndExec(urlQuery string, modelsSlice any) (*Result, error) {
|
||||
if err := s.Parse(urlQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Exec(modelsSlice)
|
||||
}
|
794
tools/search/provider_test.go
Normal file
794
tools/search/provider_test.go
Normal file
|
@ -0,0 +1,794 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r)
|
||||
|
||||
if p.page != 1 {
|
||||
t.Fatalf("Expected page %d, got %d", 1, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != DefaultPerPage {
|
||||
t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
|
||||
}
|
||||
|
||||
if p.maxFilterExprLimit != DefaultFilterExprLimit {
|
||||
t.Fatalf("Expected maxFilterExprLimit %d, got %d", DefaultFilterExprLimit, p.maxFilterExprLimit)
|
||||
}
|
||||
|
||||
if p.maxSortExprLimit != DefaultSortExprLimit {
|
||||
t.Fatalf("Expected maxSortExprLimit %d, got %d", DefaultSortExprLimit, p.maxSortExprLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxFilterExprLimit(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{})
|
||||
|
||||
testVals := []int{0, -10, 10}
|
||||
|
||||
for _, val := range testVals {
|
||||
t.Run("max_"+strconv.Itoa(val), func(t *testing.T) {
|
||||
p.MaxFilterExprLimit(val)
|
||||
|
||||
if p.maxFilterExprLimit != val {
|
||||
t.Fatalf("Expected maxFilterExprLimit to change to %d, got %d", val, p.maxFilterExprLimit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxSortExprLimit(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{})
|
||||
|
||||
testVals := []int{0, -10, 10}
|
||||
|
||||
for _, val := range testVals {
|
||||
t.Run("max_"+strconv.Itoa(val), func(t *testing.T) {
|
||||
p.MaxSortExprLimit(val)
|
||||
|
||||
if p.maxSortExprLimit != val {
|
||||
t.Fatalf("Expected maxSortExprLimit to change to %d, got %d", val, p.maxSortExprLimit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderQuery(t *testing.T) {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
querySql := query.Build().SQL()
|
||||
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Query(query)
|
||||
|
||||
expected := p.query.Build().SQL()
|
||||
|
||||
if querySql != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, querySql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderSkipTotal(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{})
|
||||
|
||||
if p.skipTotal {
|
||||
t.Fatalf("Expected the default skipTotal to be %v, got %v", false, p.skipTotal)
|
||||
}
|
||||
|
||||
p.SkipTotal(true)
|
||||
|
||||
if !p.skipTotal {
|
||||
t.Fatalf("Expected skipTotal to change to %v, got %v", true, p.skipTotal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderCountCol(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{})
|
||||
|
||||
if p.countCol != "id" {
|
||||
t.Fatalf("Expected the default countCol to be %s, got %s", "id", p.countCol)
|
||||
}
|
||||
|
||||
p.CountCol("test")
|
||||
|
||||
if p.countCol != "test" {
|
||||
t.Fatalf("Expected colCount to change to %s, got %s", "test", p.countCol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Page(10)
|
||||
|
||||
if p.page != 10 {
|
||||
t.Fatalf("Expected page %v, got %v", 10, p.page)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPerPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).PerPage(456)
|
||||
|
||||
if p.perPage != 456 {
|
||||
t.Fatalf("Expected perPage %v, got %v", 456, p.perPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderSort(t *testing.T) {
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Sort(initialSort).
|
||||
AddSort(SortField{"test3", SortDesc})
|
||||
|
||||
encoded, _ := json.Marshal(p.sort)
|
||||
expected := `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"test3","direction":"DESC"}]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected sort %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFilter(t *testing.T) {
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Filter(initialFilter).
|
||||
AddFilter("test3")
|
||||
|
||||
encoded, _ := json.Marshal(p.filter)
|
||||
expected := `["test1","test2","test3"]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected filter %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParse(t *testing.T) {
|
||||
initialPage := 2
|
||||
initialPerPage := 123
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
|
||||
scenarios := []struct {
|
||||
query string
|
||||
expectError bool
|
||||
expectPage int
|
||||
expectPerPage int
|
||||
expectSort string
|
||||
expectFilter string
|
||||
}{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
false,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid query
|
||||
{
|
||||
"invalid;",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid page
|
||||
{
|
||||
"page=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid perPage
|
||||
{
|
||||
"perPage=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// valid query parameters
|
||||
{
|
||||
"page=3&perPage=456&filter=test3&sort=-a,b,+c&other=123",
|
||||
false,
|
||||
3,
|
||||
456,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"a","direction":"DESC"},{"name":"b","direction":"ASC"},{"name":"c","direction":"ASC"}]`,
|
||||
`["test1","test2","test3"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.query), func(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Page(initialPage).
|
||||
PerPage(initialPerPage).
|
||||
Sort(initialSort).
|
||||
Filter(initialFilter)
|
||||
|
||||
err := p.Parse(s.query)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if p.page != s.expectPage {
|
||||
t.Fatalf("Expected page %v, got %v", s.expectPage, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != s.expectPerPage {
|
||||
t.Fatalf("Expected perPage %v, got %v", s.expectPerPage, p.perPage)
|
||||
}
|
||||
|
||||
encodedSort, _ := json.Marshal(p.sort)
|
||||
if string(encodedSort) != s.expectSort {
|
||||
t.Fatalf("Expected sort %v, got \n%v", s.expectSort, string(encodedSort))
|
||||
}
|
||||
|
||||
encodedFilter, _ := json.Marshal(p.filter)
|
||||
if string(encodedFilter) != s.expectFilter {
|
||||
t.Fatalf("Expected filter %v, got \n%v", s.expectFilter, string(encodedFilter))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecEmptyQuery(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{}).
|
||||
Query(nil)
|
||||
|
||||
_, err := p.Exec(&[]testTableStruct{})
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error with empty query, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
skipTotal bool
|
||||
expectError bool
|
||||
expectResult string
|
||||
expectQueries []string
|
||||
}{
|
||||
{
|
||||
"page normalization",
|
||||
-1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":10,"totalItems":2,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
{
|
||||
"perPage normalization",
|
||||
10,
|
||||
0, // fallback to default
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"items":[],"page":10,"perPage":30,"totalItems":2,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30 OFFSET 270",
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid sort field",
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"unknown", SortAsc}},
|
||||
[]FilterData{},
|
||||
false,
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"invalid filter",
|
||||
1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{"test2 = 'test2.1'", "invalid"},
|
||||
false,
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid sort and filter fields",
|
||||
1,
|
||||
5555, // will be limited by MaxPerPage
|
||||
[]SortField{{"test2", SortDesc}},
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
false,
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2)",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
|
||||
},
|
||||
},
|
||||
{
|
||||
"valid sort and filter fields (skipTotal=1)",
|
||||
1,
|
||||
5555, // will be limited by MaxPerPage
|
||||
[]SortField{{"test2", SortDesc}},
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
true,
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
|
||||
},
|
||||
},
|
||||
{
|
||||
"valid sort and filter fields (zero results)",
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"test3", SortAsc}},
|
||||
[]FilterData{"test3 != ''"},
|
||||
false,
|
||||
false,
|
||||
`{"items":[],"page":1,"perPage":10,"totalItems":0,"totalPages":0}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL)))",
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
{
|
||||
"valid sort and filter fields (zero results; skipTotal=1)",
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"test3", SortAsc}},
|
||||
[]FilterData{"test3 != ''"},
|
||||
true,
|
||||
false,
|
||||
`{"items":[],"page":1,"perPage":10,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
{
|
||||
"pagination test",
|
||||
2,
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":2,"totalPages":2}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"pagination test (skipTotal=1)",
|
||||
2,
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
true,
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
p := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(s.page).
|
||||
PerPage(s.perPage).
|
||||
Sort(s.sort).
|
||||
SkipTotal(s.skipTotal).
|
||||
Filter(s.filter)
|
||||
|
||||
result, err := p.Exec(&[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Fatalf("Expected result %v, got \n%v", s.expectResult, string(encoded))
|
||||
}
|
||||
|
||||
if len(s.expectQueries) != len(testDB.CalledQueries) {
|
||||
t.Fatalf("Expected %d queries, got %d: \n%v", len(s.expectQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
|
||||
for _, q := range testDB.CalledQueries {
|
||||
if !list.ExistInSliceWithRegex(q, s.expectQueries) {
|
||||
t.Fatalf("Didn't expect query \n%v \nin \n%v", q, s.expectQueries)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFilterAndSortLimits(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
filter []FilterData
|
||||
sort []SortField
|
||||
maxFilterExprLimit int
|
||||
maxSortExprLimit int
|
||||
expectError bool
|
||||
}{
|
||||
// filter
|
||||
{
|
||||
"<= max filter length",
|
||||
[]FilterData{
|
||||
"1=2",
|
||||
FilterData("1='" + strings.Repeat("a", MaxFilterLength-4) + "'"),
|
||||
},
|
||||
[]SortField{},
|
||||
1,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max filter length",
|
||||
[]FilterData{
|
||||
"1=2",
|
||||
FilterData("1='" + strings.Repeat("a", MaxFilterLength-3) + "'"),
|
||||
},
|
||||
[]SortField{},
|
||||
1,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= max filter exprs",
|
||||
[]FilterData{
|
||||
"1=2",
|
||||
"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
|
||||
},
|
||||
[]SortField{},
|
||||
6,
|
||||
0,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max filter exprs",
|
||||
[]FilterData{
|
||||
"1=2",
|
||||
"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
|
||||
},
|
||||
[]SortField{},
|
||||
5,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
|
||||
// sort
|
||||
{
|
||||
"<= max sort field length",
|
||||
[]FilterData{},
|
||||
[]SortField{
|
||||
{"id", SortAsc},
|
||||
{"test1", SortDesc},
|
||||
{strings.Repeat("a", MaxSortFieldLength), SortDesc},
|
||||
},
|
||||
0,
|
||||
10,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max sort field length",
|
||||
[]FilterData{},
|
||||
[]SortField{
|
||||
{"id", SortAsc},
|
||||
{"test1", SortDesc},
|
||||
{strings.Repeat("b", MaxSortFieldLength+1), SortDesc},
|
||||
},
|
||||
0,
|
||||
10,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= max sort exprs",
|
||||
[]FilterData{},
|
||||
[]SortField{
|
||||
{"id", SortAsc},
|
||||
{"test1", SortDesc},
|
||||
},
|
||||
0,
|
||||
2,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max sort exprs",
|
||||
[]FilterData{},
|
||||
[]SortField{
|
||||
{"id", SortAsc},
|
||||
{"test1", SortDesc},
|
||||
},
|
||||
0,
|
||||
1,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testResolver := &testFieldResolver{}
|
||||
p := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Sort(s.sort).
|
||||
Filter(s.filter).
|
||||
MaxFilterExprLimit(s.maxFilterExprLimit).
|
||||
MaxSortExprLimit(s.maxSortExprLimit)
|
||||
|
||||
_, err := p.Exec(&[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParseAndExec(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
queryString string
|
||||
expectError bool
|
||||
expectResult string
|
||||
}{
|
||||
{
|
||||
"no extra query params (aka. use the provider presets)",
|
||||
"",
|
||||
false,
|
||||
`{"items":[],"page":2,"perPage":123,"totalItems":2,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"invalid query",
|
||||
"invalid;",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid page",
|
||||
"page=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid perPage",
|
||||
"perPage=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid skipTotal",
|
||||
"skipTotal=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid sorting field",
|
||||
"sort=-unknown",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid filter field",
|
||||
"filter=unknown>1",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"page > existing",
|
||||
"page=3&perPage=9999",
|
||||
false,
|
||||
`{"items":[],"page":3,"perPage":1000,"totalItems":2,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"valid query params",
|
||||
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3",
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":1,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"valid query params with skipTotal=1",
|
||||
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3&skipTotal=1",
|
||||
false,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":-1,"totalPages":-1}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
provider := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(2).
|
||||
PerPage(123).
|
||||
Sort([]SortField{{"test2", SortAsc}}).
|
||||
Filter([]FilterData{"test1 > 0"})
|
||||
|
||||
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
expectedQueries := 2
|
||||
if provider.skipTotal {
|
||||
expectedQueries = 1
|
||||
}
|
||||
|
||||
if len(testDB.CalledQueries) != expectedQueries {
|
||||
t.Fatalf("Expected %d db queries, got %d: \n%v", expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Fatalf("Expected result \n%v\ngot\n%v", s.expectResult, string(encoded))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testTableStruct struct {
|
||||
Test1 int `db:"test1" json:"test1"`
|
||||
Test2 string `db:"test2" json:"test2"`
|
||||
Test3 string `db:"test3" json:"test3"`
|
||||
}
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
// using a shared cache to allow multiple connections access to
|
||||
// the same in memory database https://www.sqlite.org/inmemorydb.html
|
||||
sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.CreateTable("test", map[string]string{
|
||||
"id": "int default 0",
|
||||
"test1": "int default 0",
|
||||
"test2": "text default ''",
|
||||
"test3": "text default ''",
|
||||
strings.Repeat("a", MaxSortFieldLength): "text default ''",
|
||||
strings.Repeat("b", MaxSortFieldLength+1): "text default ''",
|
||||
}).Execute()
|
||||
db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute()
|
||||
db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute()
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
type testFieldResolver struct {
|
||||
UpdateQueryCalls int
|
||||
ResolveCalls int
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
t.UpdateQueryCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) Resolve(field string) (*ResolverResult, error) {
|
||||
t.ResolveCalls++
|
||||
|
||||
if field == "unknown" {
|
||||
return nil, errors.New("test error")
|
||||
}
|
||||
|
||||
return &ResolverResult{Identifier: field}, nil
|
||||
}
|
113
tools/search/simple_field_resolver.go
Normal file
113
tools/search/simple_field_resolver.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
// ResolverResult defines a single FieldResolver.Resolve() successfully parsed result.
|
||||
type ResolverResult struct {
|
||||
// Identifier is the plain SQL identifier/column that will be used
|
||||
// in the final db expression as left or right operand.
|
||||
Identifier string
|
||||
|
||||
// NoCoalesce instructs to not use COALESCE or NULL fallbacks
|
||||
// when building the identifier expression.
|
||||
NoCoalesce bool
|
||||
|
||||
// Params is a map with db placeholder->value pairs that will be added
|
||||
// to the query when building both resolved operands/sides in a single expression.
|
||||
Params dbx.Params
|
||||
|
||||
// MultiMatchSubQuery is an optional sub query expression that will be added
|
||||
// in addition to the combined ResolverResult expression during build.
|
||||
MultiMatchSubQuery dbx.Expression
|
||||
|
||||
// AfterBuild is an optional function that will be called after building
|
||||
// and combining the result of both resolved operands/sides in a single expression.
|
||||
AfterBuild func(expr dbx.Expression) dbx.Expression
|
||||
}
|
||||
|
||||
// FieldResolver defines an interface for managing search fields.
|
||||
type FieldResolver interface {
|
||||
// UpdateQuery allows to updated the provided db query based on the
|
||||
// resolved search fields (eg. adding joins aliases, etc.).
|
||||
//
|
||||
// Called internally by `search.Provider` before executing the search request.
|
||||
UpdateQuery(query *dbx.SelectQuery) error
|
||||
|
||||
// Resolve parses the provided field and returns a properly
|
||||
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
|
||||
Resolve(field string) (*ResolverResult, error)
|
||||
}
|
||||
|
||||
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
|
||||
// provided `allowedFields`.
|
||||
//
|
||||
// Each `allowedFields` could be a plain string (eg. "name")
|
||||
// or a regexp pattern (eg. `^\w+[\w\.]*$`).
|
||||
func NewSimpleFieldResolver(allowedFields ...string) *SimpleFieldResolver {
|
||||
return &SimpleFieldResolver{
|
||||
allowedFields: allowedFields,
|
||||
}
|
||||
}
|
||||
|
||||
// SimpleFieldResolver defines a generic search resolver that allows
|
||||
// only its listed fields to be resolved and take part in a search query.
|
||||
//
|
||||
// If `allowedFields` are empty no fields filtering is applied.
|
||||
type SimpleFieldResolver struct {
|
||||
allowedFields []string
|
||||
}
|
||||
|
||||
// UpdateQuery implements `search.UpdateQuery` interface.
|
||||
func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
// nothing to update...
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve implements `search.Resolve` interface.
|
||||
//
|
||||
// Returns error if `field` is not in `r.allowedFields`.
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
|
||||
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
|
||||
return nil, fmt.Errorf("failed to resolve field %q", field)
|
||||
}
|
||||
|
||||
parts := strings.Split(field, ".")
|
||||
|
||||
// single regular field
|
||||
if len(parts) == 1 {
|
||||
return &ResolverResult{
|
||||
Identifier: "[[" + inflector.Columnify(parts[0]) + "]]",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// treat as json path
|
||||
var jsonPath strings.Builder
|
||||
jsonPath.WriteString("$")
|
||||
for _, part := range parts[1:] {
|
||||
if _, err := strconv.Atoi(part); err == nil {
|
||||
jsonPath.WriteString("[")
|
||||
jsonPath.WriteString(inflector.Columnify(part))
|
||||
jsonPath.WriteString("]")
|
||||
} else {
|
||||
jsonPath.WriteString(".")
|
||||
jsonPath.WriteString(inflector.Columnify(part))
|
||||
}
|
||||
}
|
||||
|
||||
return &ResolverResult{
|
||||
NoCoalesce: true,
|
||||
Identifier: fmt.Sprintf(
|
||||
"JSON_EXTRACT([[%s]], '%s')",
|
||||
inflector.Columnify(parts[0]),
|
||||
jsonPath.String(),
|
||||
),
|
||||
}, nil
|
||||
}
|
87
tools/search/simple_field_resolver_test.go
Normal file
87
tools/search/simple_field_resolver_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package search_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectQuery string
|
||||
}{
|
||||
// missing field (the query shouldn't change)
|
||||
{"", `SELECT "id" FROM "test"`},
|
||||
// unknown field (the query shouldn't change)
|
||||
{"unknown", `SELECT "id" FROM "test"`},
|
||||
// allowed field (the query shouldn't change)
|
||||
{"test", `SELECT "id" FROM "test"`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
|
||||
r.Resolve(s.fieldName)
|
||||
|
||||
if err := r.UpdateQuery(nil); err != nil {
|
||||
t.Fatalf("UpdateQuery failed with error %v", err)
|
||||
}
|
||||
|
||||
rawQuery := query.Build().SQL()
|
||||
|
||||
if rawQuery != s.expectQuery {
|
||||
t.Fatalf("Expected query %v, got \n%v", s.expectQuery, rawQuery)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test", `^test_regex\d+$`, "Test columnify!", "data.test")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectError bool
|
||||
expectName string
|
||||
}{
|
||||
{"", true, ""},
|
||||
{" ", true, ""},
|
||||
{"unknown", true, ""},
|
||||
{"test", false, "[[test]]"},
|
||||
{"test.sub", true, ""},
|
||||
{"test_regex", true, ""},
|
||||
{"test_regex1", false, "[[test_regex1]]"},
|
||||
{"Test columnify!", false, "[[Testcolumnify]]"},
|
||||
{"data.test", false, "JSON_EXTRACT([[data]], '$.test')"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
|
||||
r, err := r.Resolve(s.fieldName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Identifier != s.expectName {
|
||||
t.Fatalf("Expected r.Identifier %q, got %q", s.expectName, r.Identifier)
|
||||
}
|
||||
|
||||
if len(r.Params) != 0 {
|
||||
t.Fatalf("r.Params should be empty, got %v", r.Params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
67
tools/search/sort.go
Normal file
67
tools/search/sort.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
randomSortKey string = "@random"
|
||||
rowidSortKey string = "@rowid"
|
||||
)
|
||||
|
||||
// sort field directions
|
||||
const (
|
||||
SortAsc string = "ASC"
|
||||
SortDesc string = "DESC"
|
||||
)
|
||||
|
||||
// SortField defines a single search sort field.
|
||||
type SortField struct {
|
||||
Name string `json:"name"`
|
||||
Direction string `json:"direction"`
|
||||
}
|
||||
|
||||
// BuildExpr resolves the sort field into a valid db sort expression.
|
||||
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
|
||||
// special case for random sort
|
||||
if s.Name == randomSortKey {
|
||||
return "RANDOM()", nil
|
||||
}
|
||||
|
||||
// special case for the builtin SQLite rowid column
|
||||
if s.Name == rowidSortKey {
|
||||
return fmt.Sprintf("[[_rowid_]] %s", s.Direction), nil
|
||||
}
|
||||
|
||||
result, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
if err != nil || len(result.Params) > 0 || result.Identifier == "" || strings.ToLower(result.Identifier) == "null" {
|
||||
return "", fmt.Errorf("invalid sort field %q", s.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s", result.Identifier, s.Direction), nil
|
||||
}
|
||||
|
||||
// ParseSortFromString parses the provided string expression
|
||||
// into a slice of SortFields.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// fields := search.ParseSortFromString("-name,+created")
|
||||
func ParseSortFromString(str string) (fields []SortField) {
|
||||
data := strings.Split(str, ",")
|
||||
|
||||
for _, field := range data {
|
||||
// trim whitespaces
|
||||
field = strings.TrimSpace(field)
|
||||
if strings.HasPrefix(field, "-") {
|
||||
fields = append(fields, SortField{strings.TrimPrefix(field, "-"), SortDesc})
|
||||
} else {
|
||||
fields = append(fields, SortField{strings.TrimPrefix(field, "+"), SortAsc})
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
78
tools/search/sort_test.go
Normal file
78
tools/search/sort_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package search_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSortFieldBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")
|
||||
|
||||
scenarios := []struct {
|
||||
sortField search.SortField
|
||||
expectError bool
|
||||
expectExpression string
|
||||
}{
|
||||
// empty
|
||||
{search.SortField{"", search.SortDesc}, true, ""},
|
||||
// unknown field
|
||||
{search.SortField{"unknown", search.SortAsc}, true, ""},
|
||||
// placeholder field
|
||||
{search.SortField{"'test'", search.SortAsc}, true, ""},
|
||||
// null field
|
||||
{search.SortField{"null", search.SortAsc}, true, ""},
|
||||
// allowed field - asc
|
||||
{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
|
||||
// allowed field - desc
|
||||
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
|
||||
// special @random field (ignore direction)
|
||||
{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
|
||||
// special _rowid_ field
|
||||
{search.SortField{"@rowid", search.SortDesc}, false, "[[_rowid_]] DESC"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%s_%s", s.sortField.Name, s.sortField.Name), func(t *testing.T) {
|
||||
result, err := s.sortField.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if result != s.expectExpression {
|
||||
t.Fatalf("Expected expression %v, got %v", s.expectExpression, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortFromString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value string
|
||||
expected string
|
||||
}{
|
||||
{"", `[{"name":"","direction":"ASC"}]`},
|
||||
{"test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"+test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"-test", `[{"name":"test","direction":"DESC"}]`},
|
||||
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
|
||||
{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
|
||||
{"-@rowid,-test", `[{"name":"@rowid","direction":"DESC"},{"name":"test","direction":"DESC"}]`},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.value, func(t *testing.T) {
|
||||
result := search.ParseSortFromString(s.value)
|
||||
encoded, _ := json.Marshal(result)
|
||||
encodedStr := string(encoded)
|
||||
|
||||
if encodedStr != s.expected {
|
||||
t.Fatalf("Expected expression %s, got %s", s.expected, encodedStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
56
tools/search/token_functions.go
Normal file
56
tools/search/token_functions.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ganigeorgiev/fexpr"
|
||||
)
|
||||
|
||||
var TokenFunctions = map[string]func(
|
||||
argTokenResolverFunc func(fexpr.Token) (*ResolverResult, error),
|
||||
args ...fexpr.Token,
|
||||
) (*ResolverResult, error){
|
||||
// geoDistance(lonA, latA, lonB, latB) calculates the Haversine
|
||||
// distance between 2 points in kilometres (https://www.movable-type.co.uk/scripts/latlong.html).
|
||||
//
|
||||
// The accepted arguments at the moment could be either a plain number or a column identifier (including NULL).
|
||||
// If the column identifier cannot be resolved and converted to a numeric value, it resolves to NULL.
|
||||
//
|
||||
// Similar to the built-in SQLite functions, geoDistance doesn't apply
|
||||
// a "match-all" constraints in case there are multiple relation fields arguments.
|
||||
// Or in other words, if a collection has "orgs" multiple relation field pointing to "orgs" collection that has "office" as "geoPoint" field,
|
||||
// then the filter: `geoDistance(orgs.office.lon, orgs.office.lat, 1, 2) < 200`
|
||||
// will evaluate to true if for at-least-one of the "orgs.office" records the function result in a value satisfying the condition (aka. "result < 200").
|
||||
"geoDistance": func(argTokenResolverFunc func(fexpr.Token) (*ResolverResult, error), args ...fexpr.Token) (*ResolverResult, error) {
|
||||
if len(args) != 4 {
|
||||
return nil, fmt.Errorf("[geoDistance] expected 4 arguments, got %d", len(args))
|
||||
}
|
||||
|
||||
resolvedArgs := make([]*ResolverResult, 4)
|
||||
for i, arg := range args {
|
||||
if arg.Type != fexpr.TokenIdentifier && arg.Type != fexpr.TokenNumber {
|
||||
return nil, fmt.Errorf("[geoDistance] argument %d must be an identifier or number", i)
|
||||
}
|
||||
resolved, err := argTokenResolverFunc(arg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("[geoDistance] failed to resolve argument %d: %w", i, err)
|
||||
}
|
||||
resolvedArgs[i] = resolved
|
||||
}
|
||||
|
||||
lonA := resolvedArgs[0].Identifier
|
||||
latA := resolvedArgs[1].Identifier
|
||||
lonB := resolvedArgs[2].Identifier
|
||||
latB := resolvedArgs[3].Identifier
|
||||
|
||||
return &ResolverResult{
|
||||
NoCoalesce: true,
|
||||
Identifier: `(6371 * acos(` +
|
||||
`cos(radians(` + latA + `)) * cos(radians(` + latB + `)) * ` +
|
||||
`cos(radians(` + lonB + `) - radians(` + lonA + `)) + ` +
|
||||
`sin(radians(` + latA + `)) * sin(radians(` + latB + `))` +
|
||||
`))`,
|
||||
Params: mergeParams(resolvedArgs[0].Params, resolvedArgs[1].Params, resolvedArgs[2].Params, resolvedArgs[3].Params),
|
||||
}, nil
|
||||
},
|
||||
}
|
277
tools/search/token_functions_test.go
Normal file
277
tools/search/token_functions_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ganigeorgiev/fexpr"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestTokenFunctionsGeoDistance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
fn, ok := TokenFunctions["geoDistance"]
|
||||
if !ok {
|
||||
t.Error("Expected geoDistance token function to be registered.")
|
||||
}
|
||||
|
||||
baseTokenResolver := func(t fexpr.Token) (*ResolverResult, error) {
|
||||
placeholder := "t" + security.PseudorandomString(5)
|
||||
return &ResolverResult{Identifier: "{:" + placeholder + "}", Params: map[string]any{placeholder: t.Literal}}, nil
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
args []fexpr.Token
|
||||
resolver func(t fexpr.Token) (*ResolverResult, error)
|
||||
result *ResolverResult
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
"no args",
|
||||
nil,
|
||||
baseTokenResolver,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"< 4 args",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenNumber},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"> 4 args",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenNumber},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
{Literal: "5", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unsupported function argument",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenFunction},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unsupported text argument",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenText},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"4 valid arguments but with resolver error",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenNumber},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
},
|
||||
func(t fexpr.Token) (*ResolverResult, error) {
|
||||
return nil, errors.New("test")
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"4 valid arguments",
|
||||
[]fexpr.Token{
|
||||
{Literal: "1", Type: fexpr.TokenNumber},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "3", Type: fexpr.TokenNumber},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
&ResolverResult{
|
||||
NoCoalesce: true,
|
||||
Identifier: `(6371 * acos(cos(radians({:latA})) * cos(radians({:latB})) * cos(radians({:lonB}) - radians({:lonA})) + sin(radians({:latA})) * sin(radians({:latB}))))`,
|
||||
Params: map[string]any{
|
||||
"lonA": 1,
|
||||
"latA": 2,
|
||||
"lonB": 3,
|
||||
"latB": 4,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"mixed arguments",
|
||||
[]fexpr.Token{
|
||||
{Literal: "null", Type: fexpr.TokenIdentifier},
|
||||
{Literal: "2", Type: fexpr.TokenNumber},
|
||||
{Literal: "false", Type: fexpr.TokenIdentifier},
|
||||
{Literal: "4", Type: fexpr.TokenNumber},
|
||||
},
|
||||
baseTokenResolver,
|
||||
&ResolverResult{
|
||||
NoCoalesce: true,
|
||||
Identifier: `(6371 * acos(cos(radians({:latA})) * cos(radians({:latB})) * cos(radians({:lonB}) - radians({:lonA})) + sin(radians({:latA})) * sin(radians({:latB}))))`,
|
||||
Params: map[string]any{
|
||||
"lonA": "null",
|
||||
"latA": 2,
|
||||
"lonB": false,
|
||||
"latB": 4,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
result, err := fn(s.resolver, s.args...)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
|
||||
}
|
||||
|
||||
testCompareResults(t, s.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenFunctionsGeoDistanceExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
fn, ok := TokenFunctions["geoDistance"]
|
||||
if !ok {
|
||||
t.Error("Expected geoDistance token function to be registered.")
|
||||
}
|
||||
|
||||
result, err := fn(
|
||||
func(t fexpr.Token) (*ResolverResult, error) {
|
||||
placeholder := "t" + security.PseudorandomString(5)
|
||||
return &ResolverResult{Identifier: "{:" + placeholder + "}", Params: map[string]any{placeholder: t.Literal}}, nil
|
||||
},
|
||||
fexpr.Token{Literal: "23.23033854945808", Type: fexpr.TokenNumber},
|
||||
fexpr.Token{Literal: "42.713146090563384", Type: fexpr.TokenNumber},
|
||||
fexpr.Token{Literal: "23.44920680886216", Type: fexpr.TokenNumber},
|
||||
fexpr.Token{Literal: "42.7078484153991", Type: fexpr.TokenNumber},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
column := []float64{}
|
||||
err = testDB.NewQuery("select " + result.Identifier).Bind(result.Params).Column(&column)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(column) != 1 {
|
||||
t.Fatalf("Expected exactly 1 column value as result, got %v", column)
|
||||
}
|
||||
|
||||
expected := "17.89"
|
||||
distance := fmt.Sprintf("%.2f", column[0])
|
||||
if distance != expected {
|
||||
t.Fatalf("Expected distance value %s, got %s", expected, distance)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func testCompareResults(t *testing.T, a, b *ResolverResult) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
aIsNil := a == nil
|
||||
bIsNil := b == nil
|
||||
if aIsNil != bIsNil {
|
||||
t.Fatalf("Expected aIsNil and bIsNil to be the same, got %v vs %v", aIsNil, bIsNil)
|
||||
}
|
||||
|
||||
if aIsNil && bIsNil {
|
||||
return
|
||||
}
|
||||
|
||||
aHasAfterBuild := a.AfterBuild == nil
|
||||
bHasAfterBuild := b.AfterBuild == nil
|
||||
if aHasAfterBuild != bHasAfterBuild {
|
||||
t.Fatalf("Expected aHasAfterBuild and bHasAfterBuild to be the same, got %v vs %v", aHasAfterBuild, bHasAfterBuild)
|
||||
}
|
||||
|
||||
var aAfterBuild string
|
||||
if a.AfterBuild != nil {
|
||||
aAfterBuild = a.AfterBuild(dbx.NewExp("test")).Build(testDB.DB, a.Params)
|
||||
}
|
||||
var bAfterBuild string
|
||||
if b.AfterBuild != nil {
|
||||
bAfterBuild = b.AfterBuild(dbx.NewExp("test")).Build(testDB.DB, a.Params)
|
||||
}
|
||||
if aAfterBuild != bAfterBuild {
|
||||
t.Fatalf("Expected bAfterBuild and bAfterBuild to be the same, got\n%s\nvs\n%s", aAfterBuild, bAfterBuild)
|
||||
}
|
||||
|
||||
var aMultiMatchSubQuery string
|
||||
if a.MultiMatchSubQuery != nil {
|
||||
aMultiMatchSubQuery = a.MultiMatchSubQuery.Build(testDB.DB, a.Params)
|
||||
}
|
||||
var bMultiMatchSubQuery string
|
||||
if b.MultiMatchSubQuery != nil {
|
||||
bMultiMatchSubQuery = b.MultiMatchSubQuery.Build(testDB.DB, b.Params)
|
||||
}
|
||||
if aMultiMatchSubQuery != bMultiMatchSubQuery {
|
||||
t.Fatalf("Expected bMultiMatchSubQuery and bMultiMatchSubQuery to be the same, got\n%s\nvs\n%s", aMultiMatchSubQuery, bMultiMatchSubQuery)
|
||||
}
|
||||
|
||||
if a.NoCoalesce != b.NoCoalesce {
|
||||
t.Fatalf("Expected NoCoalesce to match, got %v vs %v", a.NoCoalesce, b.NoCoalesce)
|
||||
}
|
||||
|
||||
// loose placeholders replacement
|
||||
var aResolved = a.Identifier
|
||||
for k, v := range a.Params {
|
||||
aResolved = strings.ReplaceAll(aResolved, "{:"+k+"}", fmt.Sprintf("%v", v))
|
||||
}
|
||||
var bResolved = b.Identifier
|
||||
for k, v := range b.Params {
|
||||
bResolved = strings.ReplaceAll(bResolved, "{:"+k+"}", fmt.Sprintf("%v", v))
|
||||
}
|
||||
if aResolved != bResolved {
|
||||
t.Fatalf("Expected resolved identifiers to match, got\n%s\nvs\n%s", aResolved, bResolved)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue