1
0
Fork 0

Merging upstream version 18.17.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:09:41 +01:00
parent fdf9ca761f
commit 04c9be45a8
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
90 changed files with 46581 additions and 43319 deletions

View file

@ -6,7 +6,7 @@ import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.helper import ensure_list, seq_get, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, datatype)
return expression
def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
    """Annotate a bracket (subscript) expression with the type it evaluates to.

    The arguments are annotated first, then the type is chosen as follows:

    * a slice subscript keeps the type of the subscripted expression;
    * subscripting an ARRAY-typed expression yields the first type parameter
      of that array type, or UNKNOWN when it has none;
    * subscripting a ``Map``/``VarMap`` node with one of its own keys yields
      the type of the value at the same position, or UNKNOWN when that value
      is missing or has no type;
    * anything else is UNKNOWN.

    Returns the same ``expression`` that was passed in.
    """
    self._annotate_args(expression)

    unknown = exp.DataType.Type.UNKNOWN
    subscript = expression.expressions[0]
    container = expression.this

    if isinstance(subscript, exp.Slice):
        resolved = container.type
    elif container.type.is_type(exp.DataType.Type.ARRAY):
        resolved = seq_get(container.type.expressions, 0) or unknown
    elif isinstance(container, (exp.Map, exp.VarMap)) and subscript in container.keys:
        # Keys and values are matched by position.
        matched = seq_get(container.values, container.keys.index(subscript))
        resolved = (matched.type if matched else unknown) or unknown
    else:
        resolved = unknown

    self._set_type(expression, resolved)
    return expression

View file

@ -69,7 +69,11 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
elif isinstance(expression, (exp.Where, exp.Having)) or (
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
isinstance(expression, exp.If)
and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
):
_replace_int_predicate(expression.this)
return expression

View file

@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False):
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
if constant_propagation:
node = propagate_constants(node, root)
@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
return expression
if l.__class__ in INVERSE_DATE_OPS:
l = t.cast(exp.IntervalOp, l)
a = l.this
b = l.interval()
else:
l = t.cast(exp.Binary, l)
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
@ -695,6 +698,32 @@ def simplify_concat(expression):
return concat_type(expressions=new_args)
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For ``exp.Case``:
      * ``CASE x WHEN v ...`` is rewritten in place to ``CASE WHEN x = v ...``.
      * A WHEN branch whose condition is always false is removed; if no
        branches remain, the default (or NULL) is returned.
      * A WHEN branch whose condition is always true replaces the whole CASE,
        but only when every branch before it has been eliminated. CASE picks
        the first matching branch, so an earlier branch whose outcome is not
        statically known may still win at runtime.

    For a standalone ``exp.If`` (one whose parent is not a ``exp.Case``), the
    true branch is returned when the condition is always true, and the false
    branch (or NULL) when it is always false.

    Returns the simplified expression, or ``expression`` itself when nothing
    could be decided.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a WHEN whose outcome is not statically known has been seen;
        # after that, later branches can no longer stand in for the whole CASE.
        undecided = False

        # Iterate over a snapshot because `case.pop()` below detaches branches
        # from the CASE while we are looping.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided:
                    return case.args["true"]
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
DateRange = t.Tuple[datetime.date, datetime.date]
@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
else:
return expression
l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
date = extract_date(r)
@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
rs = expression.expressions
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
ranges = []
@ -852,6 +883,10 @@ def always_true(expression):
)
def always_false(expression):
    """Return a truthy value when ``is_false`` or ``is_null`` accepts ``expression``."""
    falsy = is_false(expression)
    return falsy if falsy else is_null(expression)
def is_complement(a, b):
    """Return whether ``b`` is an ``exp.Not`` node whose operand equals ``a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a