Merging upstream version 25.32.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
160ab5bf81
commit
02152e9ba6
74 changed files with 2284 additions and 1814 deletions
|
@@ -4,10 +4,12 @@ import itertools
|
|||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
|
||||
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
||||
|
||||
|
||||
def canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
|
||||
"""Converts a sql expression into a standard form.
|
||||
|
||||
This method relies on annotate_types because many of the
|
||||
|
@@ -17,10 +19,12 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
|
|||
expression: The expression to canonicalize.
|
||||
"""
|
||||
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
def _canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = replace_date_funcs(expression)
|
||||
expression = coerce_type(expression)
|
||||
expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bools(expression, _replace_int_predicate)
|
||||
expression = remove_ascending_order(expression)
|
||||
|
@@ -68,11 +72,11 @@ COERCIBLE_DATE_OPS = (
|
|||
)
|
||||
|
||||
|
||||
def coerce_type(node: exp.Expression) -> exp.Expression:
|
||||
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
|
||||
if isinstance(node, COERCIBLE_DATE_OPS):
|
||||
_coerce_date(node.left, node.right)
|
||||
_coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
|
||||
elif isinstance(node, exp.Between):
|
||||
_coerce_date(node.this, node.args["low"])
|
||||
_coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
|
||||
elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
|
||||
*exp.DataType.TEMPORAL_TYPES
|
||||
):
|
||||
|
@@ -128,17 +132,48 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
|
||||
def _coerce_date(
|
||||
a: exp.Expression,
|
||||
b: exp.Expression,
|
||||
promote_to_inferred_datetime_type: bool,
|
||||
) -> None:
|
||||
for a, b in itertools.permutations([a, b]):
|
||||
if isinstance(b, exp.Interval):
|
||||
a = _coerce_timeunit_arg(a, b.unit)
|
||||
|
||||
a_type = a.type
|
||||
if (
|
||||
a.type
|
||||
and a.type.this in exp.DataType.TEMPORAL_TYPES
|
||||
and b.type
|
||||
and b.type.this in exp.DataType.TEXT_TYPES
|
||||
not a_type
|
||||
or a_type.this not in exp.DataType.TEMPORAL_TYPES
|
||||
or not b.type
|
||||
or b.type.this not in exp.DataType.TEXT_TYPES
|
||||
):
|
||||
_replace_cast(b, exp.DataType.Type.DATETIME)
|
||||
continue
|
||||
|
||||
if promote_to_inferred_datetime_type:
|
||||
if b.is_string:
|
||||
date_text = b.name
|
||||
if is_iso_date(date_text):
|
||||
b_type = exp.DataType.Type.DATE
|
||||
elif is_iso_datetime(date_text):
|
||||
b_type = exp.DataType.Type.DATETIME
|
||||
else:
|
||||
b_type = a_type.this
|
||||
else:
|
||||
# If b is not a datetime string, we conservatively promote it to a DATETIME,
|
||||
# in order to ensure there are no surprising truncations due to downcasting
|
||||
b_type = exp.DataType.Type.DATETIME
|
||||
|
||||
target_type = (
|
||||
b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type
|
||||
)
|
||||
else:
|
||||
target_type = a_type
|
||||
|
||||
if target_type != a_type:
|
||||
_replace_cast(a, target_type)
|
||||
|
||||
_replace_cast(b, target_type)
|
||||
|
||||
|
||||
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
|
||||
|
@@ -168,7 +203,7 @@ def _coerce_datediff_args(node: exp.DateDiff) -> None:
|
|||
e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
|
||||
|
||||
|
||||
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
|
||||
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
|
||||
node.replace(exp.cast(node.copy(), to=to))
|
||||
|
||||
|
||||
|
|
|
@@ -524,7 +524,9 @@ def _expand_struct_stars(
|
|||
this = field.this.copy()
|
||||
root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
|
||||
new_column = exp.column(
|
||||
t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
|
||||
t.cast(exp.Identifier, root),
|
||||
table=dot_column.args.get("table"),
|
||||
fields=t.cast(t.List[exp.Identifier], parts),
|
||||
)
|
||||
new_selections.append(alias(new_column, this, copy=False))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue