Merging upstream version 18.17.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
fdf9ca761f
commit
04c9be45a8
90 changed files with 46581 additions and 43319 deletions
|
@ -6,7 +6,7 @@ import typing as t
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.helper import ensure_list, subclasses
|
||||
from sqlglot.helper import ensure_list, seq_get, subclasses
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
||||
|
@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
|
||||
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
|
||||
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Bracket: lambda self, e: self._annotate_bracket(e),
|
||||
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
|
||||
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
|
@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
|
@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
self._set_type(expression, datatype)
|
||||
return expression
|
||||
|
||||
def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
|
||||
self._annotate_args(expression)
|
||||
|
||||
bracket_arg = expression.expressions[0]
|
||||
this = expression.this
|
||||
|
||||
if isinstance(bracket_arg, exp.Slice):
|
||||
self._set_type(expression, this.type)
|
||||
elif this.type.is_type(exp.DataType.Type.ARRAY):
|
||||
contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
|
||||
self._set_type(expression, contained_type)
|
||||
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
|
||||
index = this.keys.index(bracket_arg)
|
||||
value = seq_get(this.values, index)
|
||||
value_type = value.type if value else exp.DataType.Type.UNKNOWN
|
||||
self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
|
||||
else:
|
||||
self._set_type(expression, exp.DataType.Type.UNKNOWN)
|
||||
|
||||
return expression
|
||||
|
|
|
@ -69,7 +69,11 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
|
|||
_replace_int_predicate(expression.left)
|
||||
_replace_int_predicate(expression.right)
|
||||
|
||||
elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
|
||||
elif isinstance(expression, (exp.Where, exp.Having)) or (
|
||||
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
|
||||
isinstance(expression, exp.If)
|
||||
and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
|
||||
):
|
||||
_replace_int_predicate(expression.this)
|
||||
|
||||
return expression
|
||||
|
|
|
@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False):
|
|||
node = uniq_sort(node, generate, root)
|
||||
node = absorb_and_eliminate(node, root)
|
||||
node = simplify_concat(node)
|
||||
node = simplify_conditionals(node)
|
||||
|
||||
if constant_propagation:
|
||||
node = propagate_constants(node, root)
|
||||
|
@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
if l.__class__ in INVERSE_DATE_OPS:
|
||||
l = t.cast(exp.IntervalOp, l)
|
||||
a = l.this
|
||||
b = l.interval()
|
||||
else:
|
||||
l = t.cast(exp.Binary, l)
|
||||
a, b = l.left, l.right
|
||||
|
||||
if not a_predicate(a) and b_predicate(b):
|
||||
|
@ -695,6 +698,32 @@ def simplify_concat(expression):
|
|||
return concat_type(expressions=new_args)
|
||||
|
||||
|
||||
def simplify_conditionals(expression):
|
||||
"""Simplifies expressions like IF, CASE if their condition is statically known."""
|
||||
if isinstance(expression, exp.Case):
|
||||
this = expression.this
|
||||
for case in expression.args["ifs"]:
|
||||
cond = case.this
|
||||
if this:
|
||||
# Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
|
||||
cond = cond.replace(this.pop().eq(cond))
|
||||
|
||||
if always_true(cond):
|
||||
return case.args["true"]
|
||||
|
||||
if always_false(cond):
|
||||
case.pop()
|
||||
if not expression.args["ifs"]:
|
||||
return expression.args.get("default") or exp.null()
|
||||
elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
|
||||
if always_true(expression.this):
|
||||
return expression.args["true"]
|
||||
if always_false(expression.this):
|
||||
return expression.args.get("false") or exp.null()
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
DateRange = t.Tuple[datetime.date, datetime.date]
|
||||
|
||||
|
||||
|
@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
else:
|
||||
return expression
|
||||
|
||||
l = t.cast(exp.DateTrunc, l)
|
||||
unit = l.unit.name.lower()
|
||||
date = extract_date(r)
|
||||
|
||||
|
@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
|||
rs = expression.expressions
|
||||
|
||||
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
|
||||
l = t.cast(exp.DateTrunc, l)
|
||||
unit = l.unit.name.lower()
|
||||
|
||||
ranges = []
|
||||
|
@ -852,6 +883,10 @@ def always_true(expression):
|
|||
)
|
||||
|
||||
|
||||
def always_false(expression):
|
||||
return is_false(expression) or is_null(expression)
|
||||
|
||||
|
||||
def is_complement(a, b):
|
||||
return isinstance(b, exp.Not) and b.this == a
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue