1
0
Fork 0

Merging upstream version 11.4.5.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:48:10 +01:00
parent 0a06643852
commit 88f99e1c27
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
131 changed files with 53004 additions and 37079 deletions

View file

@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -108,6 +108,7 @@ class TypeAnnotator:
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
@ -116,6 +117,7 @@ class TypeAnnotator:
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@ -296,9 +298,6 @@ class TypeAnnotator:
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
if expression.type:
return expression # We've already inferred the expression's type
@ -311,9 +310,8 @@ class TypeAnnotator:
)
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_collection(value):
self._maybe_annotate(v)
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression

View file

@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
and b.type.this != exp.DataType.Type.DATE
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")

View file

@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_joins(expression):
@ -179,6 +178,4 @@ def join_condition(join):
for condition in conditions:
extract_condition(condition)
on = simplify(on)
remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
return source_key, join_key, on

View file

@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
eliminate_subqueries(expression.this)
return expression
expression = simplify(expression)
root = build_scope(expression)
# Map of alias->Scope|Table

View file

@ -1,5 +1,4 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection
def lower_identities(expression):
@ -40,13 +39,10 @@ def lower_identities(expression):
lower_identities(expression.right)
traversed |= {"this", "expression"}
for k, v in expression.args.items():
for k, v in expression.iter_expressions():
if k in traversed:
continue
for child in ensure_collection(v):
if isinstance(child, exp.Expression):
child.transform(_lower, copy=False)
v.transform(_lower, copy=False)
return expression

View file

@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def merge_subqueries(expression, leave_tables_isolated=False):
@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if set(exp.column_table_names(where.this)) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
expression.set("where", simplify(expression.args.get("where")))
expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):

View file

@ -1,29 +1,63 @@
from __future__ import annotations
import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
from sqlglot.optimizer.simplify import flatten, uniq_sort
logger = logging.getLogger("sqlglot")
def normalize(expression, dnf=False, max_distance=128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
Rewrite sqlglot AST into conjunctive normal form.
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression).sql()
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
expression (sqlglot.Expression): expression to normalize
dnf (bool): rewrite in disjunctive normal form instead
max_distance (int): the maximal estimated distance from cnf to attempt conversion
expression: expression to normalize
dnf: rewrite in disjunctive normal form instead.
max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
expression = simplify(expression)
cache: t.Dict[int, str] = {}
expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
return simplify(expression)
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
logger.info(
f"Skipping normalization because distance {distance} exceeds max {max_distance}"
)
return expression
root = node is expression
original = node.copy()
try:
node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
except OptimizeError as e:
logger.info(e)
node.replace(original)
if root:
return original
return expression
if root:
expression = node
return expression
def normalized(expression, dnf=False):
@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@ -64,30 +98,33 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
return [1]
return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
return [
return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
]
)
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
if isinstance(expression.unnest(), exp.Connector):
if normalization_distance(expression, dnf) > max_distance:
return expression
if normalized(expression, dnf=dnf):
return expression
distance = normalization_distance(expression, dnf=dnf)
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
return _distribute(a, b, from_func, to_func, cache)
return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func)
return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func)
return _distribute(a, b, from_func, to_func, cache)
return expression
def _distribute(a, b, from_func, to_func):
def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
exp.paren(from_func(c, b.left)),
exp.paren(from_func(c, b.right)),
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
a = to_func(from_func(a, b.left), from_func(a, b.right))
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
)
return _simplify(a)
def _simplify(node):
node = uniq_sort(flatten(node))
exp.replace_children(node, _simplify)
return node
return a

View file

@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
@ -29,7 +28,6 @@ def optimize_joins(expression):
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():

View file

@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@ -43,6 +44,7 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
simplify,
)
@ -78,7 +80,7 @@ def optimize(
Returns:
sqlglot.Expression: optimized expression
"""
schema = ensure_schema(schema or sqlglot.schema)
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:

View file

@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
_expand_using(scope, resolver)
using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
_expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.expression.find_all(exp.Join))
joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to source names
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
for join in joins:
@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
tables = column_tables.setdefault(identifier, [])
# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
if table not in tables:
tables.append(table)
tables[table] = None
if join_table not in tables:
tables.append(join_table)
tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
return column_tables
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
if not group:
return
def _expand_alias_refs(scope, resolver):
selects = {}
# Replace references to select aliases
def transform(node, *_):
@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
selects = {s.alias_or_name: s for s in scope.selects}
if not selects:
for s in scope.selects:
selects[s.alias_or_name] = s
select = selects.get(node.name)
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
group.transform(transform, copy=False)
for select in scope.expression.selects:
select.transform(transform, copy=False)
for modifier in ("where", "group"):
part = scope.expression.args.get(modifier)
if part:
part.transform(transform, copy=False)
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
order = scope.expression.args.get("order")
if order:
for ordered in order.expressions:
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
having = scope.expression.args.get("having")
if having:
for column in having.find_all(exp.Column):
if (
not column.table
@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
def _expand_stars(scope, resolver):
def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
continue
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
new_selections.append(
exp.alias_(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)

View file

@ -160,7 +160,7 @@ class Scope:
Yields:
exp.Expression: nodes
"""
for expression, _, _ in self.walk(bfs=bfs):
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression

View file

@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
cache = {}
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
node = uniq_sort(node, cache, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node)
node = remove_compliments(node)
node = simplify_connectors(node, root)
node = remove_compliments(node, root)
node.parent = expression.parent
node = simplify_literals(node)
node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
if expression.this == FALSE:
if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@ -104,42 +105,42 @@ def flatten(expression):
return expression
def simplify_connectors(expression):
def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
if isinstance(expression, exp.Connector):
if left == right:
if left == right:
return left
if isinstance(expression, exp.And):
if is_false(left) or is_false(right):
return exp.false()
if is_null(left) or is_null(right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
if isinstance(expression, exp.And):
if FALSE in (left, right):
return exp.false()
if NULL in (left, right):
return exp.null()
if always_true(left) and always_true(right):
return exp.true()
if always_true(left):
return right
if always_true(right):
return left
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if left == FALSE and right == FALSE:
return exp.false()
if (
(left == NULL and right == NULL)
or (left == NULL and right == FALSE)
or (left == FALSE and right == NULL)
):
return exp.null()
if left == FALSE:
return right
if right == FALSE:
return left
return _simplify_comparison(expression, left, right, or_=True)
return None
return _simplify_comparison(expression, left, right)
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return exp.true()
if is_false(left) and is_false(right):
return exp.false()
if (
(is_null(left) and is_null(right))
or (is_null(left) and is_false(right))
or (is_false(left) and is_null(right))
):
return exp.null()
if is_false(left):
return right
if is_false(right):
return left
return _simplify_comparison(expression, left, right, or_=True)
return _flat_simplify(expression, _simplify_connectors)
if isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_connectors, root)
return expression
LT_LTE = (exp.LT, exp.LTE)
@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
def remove_compliments(expression):
def remove_compliments(expression, root=True):
"""
Removing compliments.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
def uniq_sort(expression):
def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {GENERATOR.generate(e): e for e in flattened}
deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
def absorb_and_eliminate(expression):
def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
return _flat_simplify(expression, _simplify_binary)
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
if c == NULL:
if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
if a == NULL:
if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
elif NULL in (a, b):
elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
boolean = eval_boolean(expression, a, b)
boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
or isinstance(expression.this, (exp.Is, exp.Like))
or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
return expression == TRUE or isinstance(expression, exp.Literal)
return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
expression, exp.Literal
)
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
def is_false(a: exp.Expression) -> bool:
return type(a) is exp.Boolean and not a.this
def is_null(a: exp.Expression) -> bool:
return type(a) is exp.Null
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
def _flat_simplify(expression, simplifier):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
def _flat_simplify(expression, simplifier, root=True):
if root or not expression.same_parent:
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
while queue:
a = queue.popleft()
for b in queue:
result = simplifier(expression, a, b)
for b in queue:
result = simplifier(expression, a, b)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression