Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
985db29269
commit
53cf4a81a6
124 changed files with 60313 additions and 50346 deletions
|
@ -1,2 +1,9 @@
|
|||
from sqlglot.optimizer.optimizer import RULES, optimize
|
||||
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
|
||||
from sqlglot.optimizer.scope import (
|
||||
Scope,
|
||||
build_scope,
|
||||
find_all_in_scope,
|
||||
find_in_scope,
|
||||
traverse_scope,
|
||||
walk_in_scope,
|
||||
)
|
||||
|
|
|
@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
for expr_type in expressions
|
||||
},
|
||||
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
|
||||
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
|
||||
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
|
||||
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
|
||||
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
|
@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
}
|
||||
|
||||
NESTED_TYPES = {
|
||||
exp.DataType.Type.ARRAY,
|
||||
}
|
||||
|
||||
# Specifies what types a given type can be coerced into (autofilled)
|
||||
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
|
||||
|
||||
|
@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
) -> exp.DataType.Type:
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if isinstance(type1, exp.DataType):
|
||||
type1 = type1.this
|
||||
if isinstance(type2, exp.DataType):
|
||||
type2 = type2.this
|
||||
) -> exp.DataType | exp.DataType.Type:
|
||||
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
|
||||
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
|
||||
|
||||
if exp.DataType.Type.NULL in (type1, type2):
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if exp.DataType.Type.NULL in (type1_value, type2_value):
|
||||
return exp.DataType.Type.NULL
|
||||
if exp.DataType.Type.UNKNOWN in (type1, type2):
|
||||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
|
||||
if type1_value in self.NESTED_TYPES:
|
||||
return type1
|
||||
if type2_value in self.NESTED_TYPES:
|
||||
return type2
|
||||
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
|
||||
|
||||
# Note: the following "no_type_check" decorators were added because mypy was yelling due
|
||||
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
|
||||
|
@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
return self._annotate_args(expression)
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
|
||||
def _annotate_by_args(
|
||||
self, expression: E, *args: str, promote: bool = False, array: bool = False
|
||||
) -> E:
|
||||
self._annotate_args(expression)
|
||||
|
||||
expressions: t.List[exp.Expression] = []
|
||||
|
@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
elif expression.type.this in exp.DataType.FLOAT_TYPES:
|
||||
expression.type = exp.DataType.Type.DOUBLE
|
||||
|
||||
if array:
|
||||
expression.type = exp.DataType(
|
||||
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
|
||||
)
|
||||
|
||||
return expression
|
||||
|
|
|
@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
|
|||
if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
|
||||
return None
|
||||
|
||||
parent = scope.expression.parent
|
||||
# Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
|
||||
to_replace = scope.expression.parent.unwrap()
|
||||
name, cte = _new_cte(scope, existing_ctes, taken)
|
||||
table = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
|
||||
table.set("joins", to_replace.args.get("joins"))
|
||||
|
||||
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
|
||||
table.set("joins", parent.args.get("joins"))
|
||||
to_replace.replace(table)
|
||||
|
||||
parent.replace(table)
|
||||
return cte
|
||||
|
||||
|
||||
|
|
|
@ -72,8 +72,13 @@ def normalize(expression):
|
|||
if not any(join.args.get(k) for k in JOIN_ATTRS):
|
||||
join.set("kind", "CROSS")
|
||||
|
||||
if join.kind != "CROSS":
|
||||
if join.kind == "CROSS":
|
||||
join.set("on", None)
|
||||
else:
|
||||
join.set("kind", None)
|
||||
|
||||
if not join.args.get("on") and not join.args.get("using"):
|
||||
join.set("on", exp.true())
|
||||
return expression
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from sqlglot import exp
|
||||
from sqlglot.optimizer.normalize import normalized
|
||||
from sqlglot.optimizer.scope import build_scope
|
||||
from sqlglot.optimizer.scope import build_scope, find_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify
|
||||
|
||||
|
||||
|
@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
|
|||
break
|
||||
if isinstance(node, exp.Select):
|
||||
predicate.replace(exp.true())
|
||||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
inner_predicate = replace_aliases(node, predicate)
|
||||
if find_in_scope(inner_predicate, exp.AggFunc):
|
||||
node.having(inner_predicate, copy=False)
|
||||
else:
|
||||
node.where(inner_predicate, copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||
|
@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
|
|||
if isinstance(node, exp.Join):
|
||||
node.on(predicate, copy=False)
|
||||
elif isinstance(node, exp.Select):
|
||||
node.where(replace_aliases(node, predicate), copy=False)
|
||||
inner_predicate = replace_aliases(node, predicate)
|
||||
if find_in_scope(inner_predicate, exp.AggFunc):
|
||||
node.having(inner_predicate, copy=False)
|
||||
else:
|
||||
node.where(inner_predicate, copy=False)
|
||||
|
||||
|
||||
def nodes_for_predicate(predicate, sources, scope_ref_count):
|
||||
|
|
|
@ -6,7 +6,7 @@ from enum import Enum, auto
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import find_new_name
|
||||
from sqlglot.helper import ensure_collection, find_new_name
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -141,38 +141,10 @@ class Scope:
|
|||
return walk_in_scope(self.expression, bfs=bfs)
|
||||
|
||||
def find(self, *expression_types, bfs=True):
|
||||
"""
|
||||
Returns the first node in this scope which matches at least one of the specified types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression_types (type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Returns:
|
||||
exp.Expression: the node which matches the criteria or None if no node matching
|
||||
the criteria was found.
|
||||
"""
|
||||
return next(self.find_all(*expression_types, bfs=bfs), None)
|
||||
return find_in_scope(self.expression, expression_types, bfs=bfs)
|
||||
|
||||
def find_all(self, *expression_types, bfs=True):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in this scope and only yields those that
|
||||
match at least one of the specified expression types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression_types (type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Yields:
|
||||
exp.Expression: nodes
|
||||
"""
|
||||
for expression, *_ in self.walk(bfs=bfs):
|
||||
if isinstance(expression, expression_types):
|
||||
yield expression
|
||||
return find_all_in_scope(self.expression, expression_types, bfs=bfs)
|
||||
|
||||
def replace(self, old, new):
|
||||
"""
|
||||
|
@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
|
|||
for key in ("joins", "laterals", "pivots"):
|
||||
for arg in node.args.get(key) or []:
|
||||
yield from walk_in_scope(arg, bfs=bfs)
|
||||
|
||||
|
||||
def find_all_in_scope(expression, expression_types, bfs=True):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in this scope and only yields those that
|
||||
match at least one of the specified expression types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression (exp.Expression):
|
||||
expression_types (tuple[type]|type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Yields:
|
||||
exp.Expression: nodes
|
||||
"""
|
||||
for expression, *_ in walk_in_scope(expression, bfs=bfs):
|
||||
if isinstance(expression, tuple(ensure_collection(expression_types))):
|
||||
yield expression
|
||||
|
||||
|
||||
def find_in_scope(expression, expression_types, bfs=True):
|
||||
"""
|
||||
Returns the first node in this scope which matches at least one of the specified types.
|
||||
|
||||
This does NOT traverse into subscopes.
|
||||
|
||||
Args:
|
||||
expression (exp.Expression):
|
||||
expression_types (tuple[type]|type): the expression type(s) to match.
|
||||
bfs (bool): True to use breadth-first search, False to use depth-first.
|
||||
|
||||
Returns:
|
||||
exp.Expression: the node which matches the criteria or None if no node matching
|
||||
the criteria was found.
|
||||
"""
|
||||
return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
|
||||
|
|
|
@ -69,10 +69,10 @@ def simplify(expression):
|
|||
node = flatten(node)
|
||||
node = simplify_connectors(node, root)
|
||||
node = remove_compliments(node, root)
|
||||
node = simplify_coalesce(node)
|
||||
node.parent = expression.parent
|
||||
node = simplify_literals(node, root)
|
||||
node = simplify_parens(node)
|
||||
node = simplify_coalesce(node)
|
||||
|
||||
if root:
|
||||
expression.replace(node)
|
||||
|
@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True):
|
|||
def simplify_literals(expression, root=True):
|
||||
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
|
||||
return _flat_simplify(expression, _simplify_binary, root)
|
||||
elif isinstance(expression, exp.Neg):
|
||||
|
||||
if isinstance(expression, exp.Neg):
|
||||
this = expression.this
|
||||
if this.is_number:
|
||||
value = this.name
|
||||
|
@ -430,13 +431,14 @@ def simplify_parens(expression):
|
|||
|
||||
if not isinstance(this, exp.Select) and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(this, exp.Predicate)
|
||||
or isinstance(parent, exp.Paren)
|
||||
or not isinstance(this, exp.Binary)
|
||||
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
):
|
||||
return expression.this
|
||||
return this
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -488,18 +490,20 @@ def simplify_coalesce(expression):
|
|||
coalesce = coalesce if coalesce.expressions else coalesce.this
|
||||
|
||||
# This expression is more complex than when we started, but it will get simplified further
|
||||
return exp.or_(
|
||||
exp.and_(
|
||||
coalesce.is_(exp.null()).not_(copy=False),
|
||||
expression.copy(),
|
||||
return exp.paren(
|
||||
exp.or_(
|
||||
exp.and_(
|
||||
coalesce.is_(exp.null()).not_(copy=False),
|
||||
expression.copy(),
|
||||
copy=False,
|
||||
),
|
||||
exp.and_(
|
||||
coalesce.is_(exp.null()),
|
||||
type(expression)(this=arg.copy(), expression=other.copy()),
|
||||
copy=False,
|
||||
),
|
||||
copy=False,
|
||||
),
|
||||
exp.and_(
|
||||
coalesce.is_(exp.null()),
|
||||
type(expression)(this=arg.copy(), expression=other.copy()),
|
||||
copy=False,
|
||||
),
|
||||
copy=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True):
|
|||
for b in queue:
|
||||
result = simplifier(expression, a, b)
|
||||
|
||||
if result:
|
||||
if result and result is not expression:
|
||||
queue.remove(b)
|
||||
queue.appendleft(result)
|
||||
break
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue