1
0
Fork 0

Merging upstream version 18.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 20:58:22 +01:00
parent 985db29269
commit 53cf4a81a6
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
124 changed files with 60313 additions and 50346 deletions

View file

@ -1,2 +1,9 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
from sqlglot.optimizer.scope import (
Scope,
build_scope,
find_all_in_scope,
find_in_scope,
traverse_scope,
walk_in_scope,
)

View file

@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
NESTED_TYPES = {
exp.DataType.Type.ARRAY,
}
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
if isinstance(type2, exp.DataType):
type2 = type2.this
) -> exp.DataType | exp.DataType.Type:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
if exp.DataType.Type.NULL in (type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
if exp.DataType.Type.NULL in (type1_value, type2_value):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
def _annotate_by_args(
self, expression: E, *args: str, promote: bool = False, array: bool = False
) -> E:
self._annotate_args(expression)
expressions: t.List[exp.Expression] = []
@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE
if array:
expression.type = exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
)
return expression

View file

@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
parent = scope.expression.parent
# Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
to_replace = scope.expression.parent.unwrap()
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
table.set("joins", to_replace.args.get("joins"))
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
table.set("joins", parent.args.get("joins"))
to_replace.replace(table)
parent.replace(table)
return cte

View file

@ -72,8 +72,13 @@ def normalize(expression):
if not any(join.args.get(k) for k in JOIN_ATTRS):
join.set("kind", "CROSS")
if join.kind != "CROSS":
if join.kind == "CROSS":
join.set("on", None)
else:
join.set("kind", None)
if not join.args.get("on") and not join.args.get("using"):
join.set("on", exp.true())
return expression

View file

@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)
def pushdown_dnf(predicates, scope, scope_ref_count):
@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if isinstance(node, exp.Join):
node.on(predicate, copy=False)
elif isinstance(node, exp.Select):
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)
def nodes_for_predicate(predicate, sources, scope_ref_count):

View file

@ -6,7 +6,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
from sqlglot.helper import ensure_collection, find_new_name
logger = logging.getLogger("sqlglot")
@ -141,38 +141,10 @@ class Scope:
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
"""
@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the root expression whose scope is walked.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize once: a single type or a collection of types both become a tuple that
    # `isinstance` accepts. This is loop-invariant, so it is built outside the walk.
    types = tuple(ensure_collection(expression_types))

    # `walk_in_scope` yields tuples whose first item is the visited node; the remaining
    # items are not needed here. A distinct loop name avoids shadowing `expression`.
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    # Take the first match from the lazy search and stop; fall through to None
    # when the generator is exhausted without yielding anything.
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None

View file

@ -69,10 +69,10 @@ def simplify(expression):
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_compliments(node, root)
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
node = simplify_coalesce(node)
if root:
expression.replace(node)
@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True):
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
@ -430,13 +431,14 @@ def simplify_parens(expression):
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(this, exp.Predicate)
or isinstance(parent, exp.Paren)
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
return expression.this
return this
return expression
@ -488,18 +490,20 @@ def simplify_coalesce(expression):
coalesce = coalesce if coalesce.expressions else coalesce.this
# This expression is more complex than when we started, but it will get simplified further
return exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
return exp.paren(
exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
)
)
@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True):
for b in queue:
result = simplifier(expression, a, b)
if result:
if result and result is not expression:
queue.remove(b)
queue.appendleft(result)
break