Merging upstream version 18.13.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent a56b8dde5c
commit 320822f1c4
76 changed files with 21248 additions and 19605 deletions
@@ -6,6 +6,7 @@ from sqlglot import exp
 from sqlglot.errors import OptimizeError
 from sqlglot.generator import cached_generator
 from sqlglot.helper import while_changing
+from sqlglot.optimizer.scope import find_all_in_scope
 from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 
 logger = logging.getLogger("sqlglot")
@@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
     return expression
 
 
-def normalized(expression, dnf=False):
-    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
-
-    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
-
-
-def normalization_distance(expression, dnf=False):
+def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
     """
-    The difference in the number of predicates between the current expression and the normalized form.
+    Checks whether a given expression is in a normal form of interest.
 
+    Example:
+        >>> from sqlglot import parse_one
+        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
+        True
+        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
+        True
+        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
+        False
+
+    Args:
+        expression: The expression to check if it's normalized.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+    """
+    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+    return not any(
+        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
+    )
+
+
+def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
+    """
+    The difference in the number of predicates between a given expression and its normalized form.
+
     This is used as an estimate of the cost of the conversion which is exponential in complexity.
 
@@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
         4
 
     Args:
-        expression (sqlglot.Expression): expression to compute distance
-        dnf (bool): compute to dnf distance instead
+        expression: The expression to compute the normalization distance for.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+
     Returns:
-        int: difference
+        The normalization distance.
     """
     return sum(_predicate_lengths(expression, dnf)) - (
         sum(1 for _ in expression.find_all(exp.Connector)) + 1
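Reviewer note (not part of the diff): a minimal usage sketch of the reworked helpers, assuming they live in sqlglot.optimizer.normalize as in upstream sqlglot; the expected values follow the docstrings above.

from sqlglot import parse_one
from sqlglot.optimizer.normalize import normalization_distance, normalized

expr = parse_one("(a AND b) OR (c AND d)")

# The expression is in DNF, so the default CNF check fails.
print(normalized(expr))            # False
print(normalized(expr, dnf=True))  # True

# Estimated number of extra predicates a CNF conversion would introduce.
print(normalization_distance(expr))  # 4, per the docstring example above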
@@ -39,10 +39,14 @@ def optimize_joins(expression):
                 if len(other_table_names(dep)) < 2:
                     continue
 
+                operator = type(on)
                 for predicate in on.flatten():
                     if name in exp.column_table_names(predicate):
                         predicate.replace(exp.true())
-                        join.on(predicate, copy=False)
+                        predicate = exp._combine(
+                            [join.args.get("on"), predicate], operator, copy=False
+                        )
+                        join.on(predicate, append=False, copy=False)
 
     expression = reorder_joins(expression)
     expression = normalize(expression)
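Reviewer note (not part of the diff): an illustrative sketch of what the new exp._combine call changes, assuming sqlglot.optimizer.optimize_joins.optimize_joins as the entry point. The exact rendered SQL may differ; the point is that a predicate moved onto a cross join is now combined with any existing ON condition using the original operator instead of always being AND-appended.

from sqlglot import parse_one
from sqlglot.optimizer.optimize_joins import optimize_joins

# y starts out as a bare cross join; the predicate y.b = z.b in z's ON clause
# references it, so optimize_joins moves that predicate onto y's join.
sql = "SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.b = z.b"
print(optimize_joins(parse_one(sql)).sql())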
@@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema
 SELECT_ALL = object()
 
 # Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda: alias("1", "_")
+DEFAULT_SELECTION = lambda is_agg: alias(
+    exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
+)
 
 
 def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
     new_selections = []
     removed = False
     star = False
+    is_agg = False
 
     select_all = SELECT_ALL in parent_selections
 
@@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
                 star = True
             removed = True
 
+        if not is_agg and selection.find(exp.AggFunc):
+            is_agg = True
+
     if star:
         resolver = Resolver(scope, schema)
         names = {s.alias_or_name for s in new_selections}
@@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
 
     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION())
+        new_selections.append(DEFAULT_SELECTION(is_agg))
 
     scope.expression.select(*new_selections, append=False, copy=False)
 
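Reviewer note (not part of the diff): a sketch of why the default selection now depends on is_agg, using the pushdown_projections entry point shown above. The query and the expected shape are illustrative assumptions, not taken from the commit.

from sqlglot import parse_one
from sqlglot.optimizer.pushdown_projections import pushdown_projections

# The outer query never references sub's columns, so the inner projection is
# dropped. Because the dropped projection was an aggregate, the placeholder is
# now MAX(1) instead of the literal 1, keeping the subquery an aggregation
# that still returns a single row.
sql = "SELECT 1 AS c FROM (SELECT COUNT(*) AS cnt FROM t) AS sub"
print(pushdown_projections(parse_one(sql)).sql())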
@@ -137,8 +137,8 @@ class Scope:
         if not self._collected:
             self._collect()
 
-    def walk(self, bfs=True):
-        return walk_in_scope(self.expression, bfs=bfs)
+    def walk(self, bfs=True, prune=None):
+        return walk_in_scope(self.expression, bfs=bfs, prune=None)
 
     def find(self, *expression_types, bfs=True):
         return find_in_scope(self.expression, expression_types, bfs=bfs)
@@ -731,7 +731,7 @@ def _traverse_ddl(scope):
     yield from _traverse_scope(query_scope)
 
 
-def walk_in_scope(expression, bfs=True):
+def walk_in_scope(expression, bfs=True, prune=None):
     """
     Returns a generator object which visits all nodes in the syntrax tree, stopping at
     nodes that start child scopes.
@@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
         expression (exp.Expression):
         bfs (bool): if set to True the BFS traversal order will be applied,
             otherwise the DFS traversal will be used instead.
+        prune ((node, parent, arg_key) -> bool): callable that returns True if
+            the generator should stop traversing this branch of the tree.
 
     Yields:
         tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
     """
     # We'll use this variable to pass state into the dfs generator.
     # Whenever we set it to True, we exclude a subtree from traversal.
-    prune = False
+    crossed_scope_boundary = False
 
-    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
-        prune = False
+    for node, parent, key in expression.walk(
+        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+    ):
+        crossed_scope_boundary = False
 
         yield node, parent, key
 
@@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
             or isinstance(node, exp.UDTF)
             or isinstance(node, exp.Subqueryable)
         ):
-            prune = True
+            crossed_scope_boundary = True
 
             if isinstance(node, (exp.Subquery, exp.UDTF)):
                 # The following args are not actually in the inner scope, so we should visit them
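Reviewer note (not part of the diff): a sketch of the new prune hook on walk_in_scope, mirroring how propagate_constants uses it further down in this commit. The query is an arbitrary example.

from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import walk_in_scope

expr = parse_one("SELECT CASE WHEN a = 1 THEN b ELSE c END AS v FROM t")

# Visit the current scope but stop descending once an IF/CASE branch is hit;
# descendants of pruned nodes are not yielded.
for node, parent, key in walk_in_scope(expr, prune=lambda node, *_: isinstance(node, exp.If)):
    print(type(node).__name__)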
@@ -5,9 +5,11 @@ import typing as t
 from collections import deque
 from decimal import Decimal
 
+import sqlglot
 from sqlglot import exp
 from sqlglot.generator import cached_generator
 from sqlglot.helper import first, merge_ranges, while_changing
+from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
 
 # Final means that an expression should not be simplified
 FINAL = "final"
@@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
     pass
 
 
-def simplify(expression):
+def simplify(expression, constant_propagation=False):
     """
     Rewrite sqlglot AST to simplify expressions.
 
@@ -29,6 +31,8 @@ def simplify(expression):
 
     Args:
         expression (sqlglot.Expression): expression to simplify
+        constant_propagation: whether or not the constant propagation rule should be used
+
     Returns:
         sqlglot.Expression: simplified expression
     """
@@ -67,13 +71,16 @@ def simplify(expression):
         node = absorb_and_eliminate(node, root)
         node = simplify_concat(node)
 
+        if constant_propagation:
+            node = propagate_constants(node, root)
+
         exp.replace_children(node, lambda e: _simplify(e, False))
 
         # Post-order transformations
         node = simplify_not(node)
         node = flatten(node)
         node = simplify_connectors(node, root)
-        node = remove_compliments(node, root)
+        node = remove_complements(node, root)
         node = simplify_coalesce(node)
         node.parent = expression.parent
         node = simplify_literals(node, root)
@@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False):
     return None
 
 
-def remove_compliments(expression, root=True):
+def remove_complements(expression, root=True):
     """
-    Removing compliments.
+    Removing complements.
 
     A AND NOT A -> FALSE
     A OR NOT A -> TRUE
     """
     if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 
         for a, b in itertools.permutations(expression.flatten(), 2):
             if is_complement(a, b):
-                return compliment
+                return complement
     return expression
 
 
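Reviewer note (not part of the diff): the rename is cosmetic, but as a sanity check the behaviour documented above can be exercised through simplify.

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("a = 1 AND NOT a = 1")).sql())  # expected: FALSE
print(simplify(parse_one("a = 1 OR NOT a = 1")).sql())   # expected: TRUE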
@@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True):
     return expression
 
 
+def propagate_constants(expression, root=True):
+    """
+    Propagate constants for conjunctions in DNF:
+
+    SELECT * FROM t WHERE a = b AND b = 5 becomes
+    SELECT * FROM t WHERE a = 5 AND b = 5
+
+    Reference: https://www.sqlite.org/optoverview.html
+    """
+
+    if (
+        isinstance(expression, exp.And)
+        and (root or not expression.same_parent)
+        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
+    ):
+        constant_mapping = {}
+        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+            if isinstance(expr, exp.EQ):
+                l, r = expr.left, expr.right
+
+                # TODO: create a helper that can be used to detect nested literal expressions such
+                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
+                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
+                    pass
+                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
+                    l, r = r, l
+                else:
+                    continue
+
+                constant_mapping[l] = (id(l), r)
+
+        if constant_mapping:
+            for column in find_all_in_scope(expression, exp.Column):
+                parent = column.parent
+                column_id, constant = constant_mapping.get(column) or (None, None)
+                if (
+                    column_id is not None
+                    and id(column) != column_id
+                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
+                ):
+                    column.replace(constant.copy())
+
+    return expression
+
+
 INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
     exp.DateAdd: exp.Sub,
     exp.DateSub: exp.Add,
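Reviewer note (not part of the diff): a sketch of the new rule through the updated simplify entry point; the flag name and behaviour follow the signature and docstring added above.

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

condition = parse_one("a = b AND b = 5")

# Off by default: no constant propagation is applied.
print(simplify(condition.copy()).sql())

# With constant propagation enabled, b's known value is substituted into a = b,
# roughly: a = 5 AND b = 5 (mirroring the docstring example above).
print(simplify(condition, constant_propagation=True).sql())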
@@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 
 def simplify_concat(expression):
     """Reduces all groups that contain string literals by concatenating them."""
-    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+    if not isinstance(expression, CONCATS) or (
+        # We can't reduce a CONCAT_WS call if we don't statically know the separator
+        isinstance(expression, exp.ConcatWs)
+        and not expression.expressions[0].is_string
+    ):
         return expression
 
+    if isinstance(expression, exp.ConcatWs):
+        sep_expr, *expressions = expression.expressions
+        sep = sep_expr.name
+        concat_type = exp.ConcatWs
+    else:
+        expressions = expression.expressions
+        sep = ""
+        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+
     new_args = []
     for is_string_group, group in itertools.groupby(
-        expression.expressions or expression.flatten(), lambda e: e.is_string
+        expressions or expression.flatten(), lambda e: e.is_string
     ):
         if is_string_group:
-            new_args.append(exp.Literal.string("".join(string.name for string in group)))
+            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
         else:
             new_args.extend(group)
 
-    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
-    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
-    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+    if len(new_args) == 1 and new_args[0].is_string:
+        return new_args[0]
+
+    if concat_type is exp.ConcatWs:
+        new_args = [sep_expr] + new_args
+
+    return concat_type(expressions=new_args)
 
 
 DateRange = t.Tuple[datetime.date, datetime.date]
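Reviewer note (not part of the diff): a sketch of the CONCAT_WS folding enabled by this hunk; the exact output formatting may vary, but adjacent string literals should now be merged using the literal separator.

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# The separator '-' is a string literal, so the adjacent literals 'a' and 'b'
# can be pre-joined; the non-literal column x is left in place.
print(simplify(parse_one("CONCAT_WS('-', 'a', 'b', x)")).sql())
# roughly: CONCAT_WS('-', 'a-b', x)

# With an unknown separator nothing can be folded.
print(simplify(parse_one("CONCAT_WS(sep, 'a', 'b', x)")).sql())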