Merging upstream version 11.4.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
0a06643852
commit
88f99e1c27
131 changed files with 53004 additions and 37079 deletions
@@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


@@ -108,6 +109,7 @@ class TypeAnnotator:
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
@@ -116,6 +117,7 @@ class TypeAnnotator:
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -296,9 +298,6 @@ class TypeAnnotator:
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        if not isinstance(expression, exp.Expression):
            return None

        if expression.type:
            return expression  # We've already inferred the expression's type

@@ -311,9 +310,8 @@ class TypeAnnotator:
        )

    def _annotate_args(self, expression):
        for value in expression.args.values():
            for v in ensure_collection(value):
                self._maybe_annotate(v)
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

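The hunks above extend the annotator table (CONCAT and CONCAT_WS now resolve to VARCHAR) and switch argument traversal to `iter_expressions()`. A minimal usage sketch, assuming only the public `annotate_types` entry point and the default annotators:

```python
# Hedged sketch: infer the result type of a CONCAT call via the annotator table above.
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

annotated = annotate_types(sqlglot.parse_one("SELECT CONCAT(a, b) AS ab FROM t"))
print(annotated.find(exp.Concat).type)  # expected: a VARCHAR DataType, per the table above
```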
@@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
            a.type
            and a.type.this == exp.DataType.Type.DATE
            and b.type
            and b.type.this != exp.DataType.Type.DATE
            and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
        ):
            _replace_cast(b, "date")

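The `_coerce_date` tweak stops wrapping INTERVAL operands in a DATE cast. A rough illustration, assuming canonicalize still runs on a type-annotated tree (it sits after annotate_types in the rule list):

```python
# Hedged sketch: the string literal compared against a DATE should get a DATE cast,
# while per this hunk an INTERVAL operand would now be left untouched.
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

ast = annotate_types(sqlglot.parse_one("SELECT CAST(x AS DATE) = '2022-01-01' FROM t"))
print(canonicalize(ast).sql())
```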
@@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_joins(expression):
@@ -179,6 +178,4 @@ def join_condition(join):
        for condition in conditions:
            extract_condition(condition)

    on = simplify(on)
    remaining_condition = None if on == exp.true() else on
    return source_key, join_key, remaining_condition
    return source_key, join_key, on

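With `join_condition` now returning the raw ON expression, the module no longer needs the simplify import. Usage of the rule itself is unchanged; a sketch along the lines of its docstring:

```python
# Hedged sketch: a LEFT JOIN against a DISTINCT subquery whose columns are never
# selected can be dropped entirely.
import sqlglot
from sqlglot.optimizer.eliminate_joins import eliminate_joins

sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
print(eliminate_joins(sqlglot.parse_one(sql)).sql())  # expected: 'SELECT x.a FROM x'
```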
@@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify


def eliminate_subqueries(expression):
@@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
        eliminate_subqueries(expression.this)
        return expression

    expression = simplify(expression)
    root = build_scope(expression)

    # Map of alias->Scope|Table

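eliminate_subqueries no longer pre-simplifies its input; simplify runs as its own optimizer rule. Its observable behaviour is otherwise the same, roughly:

```python
# Hedged sketch: derived tables are lifted into CTEs.
import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

print(eliminate_subqueries(sqlglot.parse_one("SELECT * FROM (SELECT * FROM x) AS y")).sql())
# expected: 'WITH y AS (SELECT * FROM x) SELECT * FROM y'
```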
@@ -1,5 +1,4 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection


def lower_identities(expression):
@@ -40,13 +39,10 @@ def lower_identities(expression):
        lower_identities(expression.right)
        traversed |= {"this", "expression"}

    for k, v in expression.args.items():
    for k, v in expression.iter_expressions():
        if k in traversed:
            continue

        for child in ensure_collection(v):
            if isinstance(child, exp.Expression):
                child.transform(_lower, copy=False)
        v.transform(_lower, copy=False)

    return expression

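The traversal now uses `iter_expressions()` instead of `ensure_collection` over raw args, so the helper import goes away; behaviour should be unchanged. A sketch of what the rule does:

```python
# Hedged sketch: unquoted identifiers are lowercased, quoted identifiers and output
# aliases are left alone.
import sqlglot
from sqlglot.optimizer.lower_identities import lower_identities

print(lower_identities(sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')).sql())
# expected: 'SELECT bar.a AS A FROM "Foo".bar'
```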
@@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify


def merge_subqueries(expression, leave_tables_isolated=False):
@@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):

        if set(exp.column_table_names(where.this)) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", simplify(from_or_join.args.get("on")))
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
    expression.set("where", simplify(expression.args.get("where")))
    expression.set("where", expression.args.get("where"))


def _merge_order(outer_scope, inner_scope):

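`_merge_where` now sets the merged ON/WHERE condition as-is instead of simplifying it inline; the standalone simplify rule picks that up later. Rough usage, loosely following the module's docstring example:

```python
# Hedged sketch: a mergeable derived table is folded into the outer query.
import sqlglot
from sqlglot.optimizer.merge_subqueries import merge_subqueries

sql = "SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y"
print(merge_subqueries(sqlglot.parse_one(sql)).sql())
# expected: something like 'SELECT x.a FROM x CROSS JOIN y'
```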
@@ -1,29 +1,63 @@
from __future__ import annotations

import logging
import typing as t

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
from sqlglot.optimizer.simplify import flatten, uniq_sort

logger = logging.getLogger("sqlglot")


def normalize(expression, dnf=False, max_distance=128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form.
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression).sql()
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression (sqlglot.Expression): expression to normalize
        dnf (bool): rewrite in disjunctive normal form instead
        max_distance (int): the maximal estimated distance from cnf to attempt conversion
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    expression = simplify(expression)
    cache: t.Dict[int, str] = {}

    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
    return simplify(expression)
    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue

            distance = normalization_distance(node, dnf=dnf)

            if distance > max_distance:
                logger.info(
                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
                )
                return expression

            root = node is expression
            original = node.copy()
            try:
                node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
            except OptimizeError as e:
                logger.info(e)
                node.replace(original)
                if root:
                    return original
                return expression

            if root:
                expression = node

    return expression


def normalized(expression, dnf=False):
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
        int: difference
    """
    return sum(_predicate_lengths(expression, dnf)) - (
        len(list(expression.find_all(exp.Connector))) + 1
        sum(1 for _ in expression.find_all(exp.Connector)) + 1
    )


@@ -64,30 +98,33 @@ def _predicate_lengths(expression, dnf):
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        return [1]
        return (1,)

    left, right = expression.args.values()

    if isinstance(expression, exp.And if dnf else exp.Or):
        return [
        return tuple(
            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
        ]
        )
    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)


def distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
    """
    if isinstance(expression.unnest(), exp.Connector):
        if normalization_distance(expression, dnf) > max_distance:
            return expression
        if normalized(expression, dnf=dnf):
            return expression

        distance = normalization_distance(expression, dnf=dnf)

        if distance > max_distance:
            raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

        exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
        to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

        exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))

        if isinstance(expression, from_exp):
            a, b = expression.unnest_operands()

@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):

            if isinstance(a, to_exp) and isinstance(b, to_exp):
                if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                    return _distribute(a, b, from_func, to_func)
                return _distribute(b, a, from_func, to_func)
                    return _distribute(a, b, from_func, to_func, cache)
                return _distribute(b, a, from_func, to_func, cache)
            if isinstance(a, to_exp):
                return _distribute(b, a, from_func, to_func)
                return _distribute(b, a, from_func, to_func, cache)
            if isinstance(b, to_exp):
                return _distribute(a, b, from_func, to_func)
                return _distribute(a, b, from_func, to_func, cache)

    return expression


def _distribute(a, b, from_func, to_func):
def _distribute(a, b, from_func, to_func, cache):
    if isinstance(a, exp.Connector):
        exp.replace_children(
            a,
            lambda c: to_func(
                exp.paren(from_func(c, b.left)),
                exp.paren(from_func(c, b.right)),
                uniq_sort(flatten(from_func(c, b.left)), cache),
                uniq_sort(flatten(from_func(c, b.right)), cache),
            ),
        )
    else:
        a = to_func(from_func(a, b.left), from_func(a, b.right))
        a = to_func(
            uniq_sort(flatten(from_func(a, b.left)), cache),
            uniq_sort(flatten(from_func(a, b.right)), cache),
        )

    return _simplify(a)


def _simplify(node):
    node = uniq_sort(flatten(node))
    exp.replace_children(node, _simplify)
    return node
    return a

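The rewritten normalize() walks each connector, skips ones that are already normalized, and bails out with a log message when the estimated normalization distance exceeds max_distance; distributive_law now raises OptimizeError instead of silently returning. A small sketch using the docstring example above:

```python
# Hedged sketch: CNF conversion of the docstring example, plus the distance estimate
# that the new max_distance check is based on.
import sqlglot
from sqlglot.optimizer.normalize import normalize, normalization_distance

expression = sqlglot.parse_one("(x AND y) OR z")
print(normalize(expression.copy()).sql())          # '(x OR z) AND (y OR z)'
print(normalize(expression.copy(), dnf=True).sql())  # already in DNF, so returned as-is
print(normalization_distance(expression, dnf=False))
```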
@@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify


def optimize_joins(expression):
@@ -29,7 +28,6 @@ def optimize_joins(expression):
        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                if isinstance(on, exp.Connector):
                    for predicate in on.flatten():

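optimize_joins no longer simplifies each dependency's ON clause while reordering cross joins, so its simplify import goes away. A usage sketch:

```python
# Hedged sketch: the cross join is rewritten into an ordinary join once a later
# join's predicate references it.
import sqlglot
from sqlglot.optimizer.optimize_joins import optimize_joins

sql = "SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.b = z.b"
print(optimize_joins(sqlglot.parse_one(sql)).sql())
```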
@@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema


@@ -43,6 +44,7 @@ RULES = (
    eliminate_ctes,
    annotate_types,
    canonicalize,
    simplify,
)


@@ -78,7 +80,7 @@ def optimize(
    Returns:
        sqlglot.Expression: optimized expression
    """
    schema = ensure_schema(schema or sqlglot.schema)
    schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
    expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
    for rule in rules:

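Two visible changes here: the rule list around annotate_types/canonicalize/simplify, and `ensure_schema` now receiving the dialect so that mapping-style schemas are interpreted with dialect-specific rules. A usage sketch:

```python
# Hedged sketch: optimize() accepts a raw SQL string, a mapping schema and a dialect;
# the dialect is now forwarded to ensure_schema as shown in the hunk above.
import sqlglot
from sqlglot.optimizer import optimize

optimized = optimize(
    "SELECT a FROM t WHERE TRUE AND a > 0",
    schema={"t": {"a": "INT"}},
    dialect="duckdb",  # any supported dialect; chosen here only for illustration
)
print(optimized.sql())
```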
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        using_column_tables = _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)
        _expand_alias_refs(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):


def _expand_using(scope, resolver):
    joins = list(scope.expression.find_all(exp.Join))
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables = {}

    for join in joins:
@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
                )
            )

            tables = column_tables.setdefault(identifier, [])
            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables.append(table)
                tables[table] = None
            if join_table not in tables:
                tables.append(join_table)
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))
@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):

            scope.replace(column, replacement)

    return column_tables


def _expand_group_by(scope, resolver):
    group = scope.expression.args.get("group")
    if not group:
        return

def _expand_alias_refs(scope, resolver):
    selects = {}

    # Replace references to select aliases
    def transform(node, *_):
@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
                node.set("table", table)
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            if not selects:
                for s in scope.selects:
                    selects[s.alias_or_name] = s
            select = selects.get(node.name)

            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):

        return node

    group.transform(transform, copy=False)
    for select in scope.expression.selects:
        select.transform(transform, copy=False)

    for modifier in ("where", "group"):
        part = scope.expression.args.get(modifier)

        if part:
            part.transform(transform, copy=False)


def _expand_group_by(scope, resolver):
    group = scope.expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)

@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
            column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    columns_missing_from_scope = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)
    order = scope.expression.args.get("order")

    if order:
        for ordered in order.expressions:
            for column in ordered.find_all(exp.Column):
                if (
                    not column.table
                    and column.parent is not ordered
                    and column.name in resolver.all_columns
                ):
                    columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
    having = scope.expression.args.get("having")

    if having:
        for column in having.find_all(exp.Column):
            if (
                not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
                column.set("table", column_table)


def _expand_stars(scope, resolver):
def _expand_stars(scope, resolver, using_column_tables):
    """Expand stars to lists of column selections"""

    new_selections = []
    except_columns = {}
    replace_columns = {}
    coalesced_columns = set()

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
            if columns and "*" not in columns:
                table_id = id(table)
                for name in columns:
                    if name not in except_columns.get(table_id, set()):
                    if name in using_column_tables and table in using_column_tables[name]:
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=table) for table in tables]

                        new_selections.append(
                            exp.alias_(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
                            )
                        )
                    elif name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)

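The qualify_columns changes thread the USING column mapping into star expansion so that a column joined via USING is emitted once, wrapped in COALESCE, instead of once per source. A sketch, assuming a plain mapping schema is acceptable (it is passed through ensure_schema):

```python
# Hedged sketch: star expansion over a USING join; the shared "id" column should
# come back as a single COALESCE projection rather than as t1.id and t2.id twice.
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t1": {"id": "INT", "a": "INT"}, "t2": {"id": "INT", "b": "INT"}})
sql = "SELECT * FROM t1 JOIN t2 USING (id)"
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
```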
@@ -160,7 +160,7 @@ class Scope:
        Yields:
            exp.Expression: nodes
        """
        for expression, _, _ in self.walk(bfs=bfs):
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

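Scope.find_all now unpacks walk() results with `*_`, matching similar changes elsewhere in this release; its behaviour is unchanged. A quick sketch on a built scope:

```python
# Hedged sketch: iterate expressions of a single scope without descending into subqueries.
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import build_scope

scope = build_scope(sqlglot.parse_one("SELECT a FROM t JOIN u ON t.id = u.id"))
print([column.sql() for column in scope.find_all(exp.Column)])
```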
@@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing

GENERATOR = Generator(normalize=True, identify=True)
GENERATOR = Generator(normalize=True, identify="safe")


def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
        sqlglot.Expression: simplified expression
    """

    cache = {}

    def _simplify(expression, root=True):
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node)
        node = absorb_and_eliminate(node)
        node = uniq_sort(node, cache, root)
        node = absorb_and_eliminate(node, root)
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node)
        node = remove_compliments(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node.parent = expression.parent
        node = simplify_literals(node)
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
    NOT (x AND y) -> NOT x OR NOT y
    """
    if isinstance(expression, exp.Not):
        if isinstance(expression.this, exp.Null):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
@@ -78,11 +79,11 @@ def simplify_not(expression):
                return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
            if isinstance(condition, exp.Or):
                return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
            if isinstance(condition, exp.Null):
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if expression.this == FALSE:
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
@@ -104,42 +105,42 @@ def flatten(expression):
    return expression


def simplify_connectors(expression):
def simplify_connectors(expression, root=True):
    def _simplify_connectors(expression, left, right):
        if isinstance(expression, exp.Connector):
            if left == right:
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            if isinstance(expression, exp.And):
                if FALSE in (left, right):
                    return exp.false()
                if NULL in (left, right):
                    return exp.null()
                if always_true(left) and always_true(right):
                    return exp.true()
                if always_true(left):
                    return right
                if always_true(right):
                    return left
                return _simplify_comparison(expression, left, right)
            elif isinstance(expression, exp.Or):
                if always_true(left) or always_true(right):
                    return exp.true()
                if left == FALSE and right == FALSE:
                    return exp.false()
                if (
                    (left == NULL and right == NULL)
                    or (left == NULL and right == FALSE)
                    or (left == FALSE and right == NULL)
                ):
                    return exp.null()
                if left == FALSE:
                    return right
                if right == FALSE:
                    return left
                return _simplify_comparison(expression, left, right, or_=True)
        return None
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    return _flat_simplify(expression, _simplify_connectors)
    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
    return None


def remove_compliments(expression):
def remove_compliments(expression, root=True):
    """
    Removing compliments.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector):
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
    return expression


def uniq_sort(expression):
def uniq_sort(expression, cache=None, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector):
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {GENERATOR.generate(e): e for e in flattened}
        deduped = {GENERATOR.generate(e, cache): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
                expression = result_func(*(e for _, e in sorted(arr)))
                break
        else:
            # we didn't have to sort but maybe we need to dedup
@@ -262,7 +263,7 @@ def uniq_sort(expression):
    return expression


def absorb_and_eliminate(expression):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector):
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
    return expression


def simplify_literals(expression):
    if isinstance(expression, exp.Binary):
        return _flat_simplify(expression, _simplify_binary)
def simplify_literals(expression, root=True):
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
            c = b
            not_ = False

        if c == NULL:
        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if a == NULL:
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif NULL in (a, b):
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a, b)
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
@@ -381,7 +382,7 @@ def simplify_parens(expression):
        and not isinstance(expression.this, exp.Select)
        and (
            not isinstance(expression.parent, (exp.Condition, exp.Binary))
            or isinstance(expression.this, (exp.Is, exp.Like))
            or isinstance(expression.this, exp.Predicate)
            or not isinstance(expression.this, exp.Binary)
        )
    ):
@@ -400,13 +401,23 @@ def remove_where_true(expression):


def always_true(expression):
    return expression == TRUE or isinstance(expression, exp.Literal)
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def is_complement(a, b):
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
@@ -466,24 +477,27 @@ def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier):
    operands = []
    queue = deque(expression.flatten(unnest=False))
    size = len(queue)
def _flat_simplify(expression, simplifier, root=True):
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

    while queue:
        a = queue.popleft()
        while queue:
            a = queue.popleft()

        for b in queue:
            result = simplifier(expression, a, b)
            for b in queue:
                result = simplifier(expression, a, b)

            if result:
                queue.remove(b)
                queue.append(result)
                break
        else:
            operands.append(a)
                if result:
                    queue.remove(b)
                    queue.append(result)
                    break
            else:
                operands.append(a)

    if len(operands) < size:
        return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression

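The simplify changes are mostly plumbing: a generator cache for uniq_sort, root/same_parent guards so already-processed nested connectors are skipped, and direct comparisons against the FALSE/NULL singletons replaced by the new is_false/is_null helpers. End-to-end behaviour stays the same; a sketch:

```python
# Hedged sketch: duplicate predicates are deduped and the tautology is dropped; the
# new helpers replace direct comparison against the FALSE/NULL singleton expressions.
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.simplify import simplify, is_false, is_null

print(simplify(sqlglot.parse_one("SELECT * FROM t WHERE 1 = 1 AND x > 5 AND x > 5")).sql())
print(is_false(exp.false()), is_null(exp.null()))  # True True
```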