Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent d4fe7bdb16
commit 90988d8258
127 changed files with 73384 additions and 73067 deletions

sqlglot/optimizer/annotate_types.py
@@ -1,12 +1,18 @@
 from __future__ import annotations

-import datetime
 import functools
 import typing as t

 from sqlglot import exp
 from sqlglot._typing import E
-from sqlglot.helper import ensure_list, seq_get, subclasses
+from sqlglot.helper import (
+    ensure_list,
+    is_date_unit,
+    is_iso_date,
+    is_iso_datetime,
+    seq_get,
+    subclasses,
+)
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema, ensure_schema

@@ -20,10 +26,6 @@ if t.TYPE_CHECKING:
     ]


-# Interval units that operate on date components
-DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
-
-
 def annotate_types(
     expression: E,
     schema: t.Optional[t.Dict | Schema] = None,
@@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
     return lambda self, e: self._annotate_with_type(e, data_type)


-def _is_iso_date(text: str) -> bool:
-    try:
-        datetime.date.fromisoformat(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _is_iso_datetime(text: str) -> bool:
-    try:
-        datetime.datetime.fromisoformat(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
     date_text = l.name
-    unit = r.text("unit").lower()
-
-    is_iso_date = _is_iso_date(date_text)
+    is_iso_date_ = is_iso_date(date_text)

-    if is_iso_date and unit in DATE_UNITS:
-        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+    if is_iso_date_ and is_date_unit(unit):
         return exp.DataType.Type.DATE

     # An ISO date is also an ISO datetime, but not vice versa
-    if is_iso_date or _is_iso_datetime(date_text):
-        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+    if is_iso_date_ or is_iso_datetime(date_text):
         return exp.DataType.Type.DATETIME

     return exp.DataType.Type.UNKNOWN


-def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
-    unit = r.text("unit").lower()
-    if unit not in DATE_UNITS:
+def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
+    if not is_date_unit(unit):
         return exp.DataType.Type.DATETIME
     return l.type.this if l.type else exp.DataType.Type.UNKNOWN

@@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             exp.Date,
             exp.DateFromParts,
             exp.DateStrToDate,
-            exp.DateTrunc,
             exp.DiToDate,
             exp.StrToDate,
             exp.TimeStrToDate,
@@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.DataType.Type.DOUBLE: {
             exp.ApproxQuantile,
             exp.Avg,
+            exp.Div,
             exp.Exp,
             exp.Ln,
             exp.Log,
@@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         },
         exp.DataType.Type.INT: {
             exp.Ceil,
+            exp.DateDiff,
             exp.DatetimeDiff,
-            exp.DateDiff,
             exp.Extract,
             exp.TimestampDiff,
             exp.TimeDiff,
@@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             exp.GroupConcat,
             exp.Initcap,
             exp.Lower,
-            exp.SafeConcat,
-            exp.SafeDPipe,
             exp.Substring,
             exp.TimeToStr,
             exp.TimeToTimeStr,
@@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
             for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
             for expr_type in expressions
         },
         exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
         exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
         exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
@@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
         exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
-        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
-        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
+        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
+        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
+        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
         exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+        exp.Div: lambda self, e: self._annotate_div(e),
         exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
         exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
         exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
         exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
         exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
         exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
@@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
     BINARY_COERCIONS: BinaryCoercions = {
         **swap_all(
             {
-                (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+                (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
+                    l, r.args.get("unit")
+                )
                 for t in exp.DataType.TEXT_TYPES
             }
         ),
         **swap_all(
             {
-                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+                # text + numeric will yield the numeric type to match most dialects' semantics
+                (text, numeric): lambda l, r: t.cast(
+                    exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
+                )
+                for text in exp.DataType.TEXT_TYPES
+                for numeric in exp.DataType.NUMERIC_TYPES
+            }
+        ),
+        **swap_all(
+            {
+                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
+                    l, r.args.get("unit")
+                ),
             }
         ),
     }
@@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

         return expression

-    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+    def _annotate_timeunit(
+        self, expression: exp.TimeUnit | exp.DateTrunc
+    ) -> exp.TimeUnit | exp.DateTrunc:
         self._annotate_args(expression)

         if expression.this.type.this in exp.DataType.TEXT_TYPES:
-            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
-        elif (
-            expression.this.type.is_type(exp.DataType.Type.DATE)
-            and expression.text("unit").lower() not in DATE_UNITS
-        ):
-            datatype = exp.DataType.Type.DATETIME
+            datatype = _coerce_date_literal(expression.this, expression.unit)
+        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
+            datatype = _coerce_date(expression.this, expression.unit)
         else:
-            datatype = expression.this.type
+            datatype = exp.DataType.Type.UNKNOWN

         self._set_type(expression, datatype)
         return expression
@@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         self._set_type(expression, exp.DataType.Type.UNKNOWN)

         return expression
+
+    def _annotate_div(self, expression: exp.Div) -> exp.Div:
+        self._annotate_args(expression)
+
+        left_type, right_type = expression.left.type.this, expression.right.type.this  # type: ignore
+
+        if (
+            expression.args.get("typed")
+            and left_type in exp.DataType.INTEGER_TYPES
+            and right_type in exp.DataType.INTEGER_TYPES
+        ):
+            self._set_type(expression, exp.DataType.Type.BIGINT)
+        else:
+            self._set_type(expression, self._maybe_coerce(left_type, right_type))
+
+        return expression
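The new _coerce_date_literal path (replacing the local DATE_UNITS/_is_iso_date helpers with the shared ones in sqlglot.helper) can be exercised through the public annotate_types entry point. A minimal sketch, not part of the diff; the query text and the expected DATE result are illustrative assumptions based on the coercion rules above:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# A text literal plus a day-level interval: under the new BINARY_COERCIONS,
# the ISO date text is expected to resolve to the DATE type.
expr = sqlglot.parse_one("SELECT '2023-01-01' + INTERVAL '1' DAY AS d")
annotated = annotate_types(expr)
print(annotated.selects[0].type)  # expected: DATE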

sqlglot/optimizer/canonicalize.py
@@ -1,8 +1,10 @@
 from __future__ import annotations

 import itertools
+import typing as t

 from sqlglot import exp
+from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime


 def canonicalize(expression: exp.Expression) -> exp.Expression:
@@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
     expression = replace_date_funcs(expression)
     expression = coerce_type(expression)
     expression = remove_redundant_casts(expression)
-    expression = ensure_bool_predicates(expression)
+    expression = ensure_bools(expression, _replace_int_predicate)
     expression = remove_ascending_order(expression)

     return expression
@@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
     return node


+COERCIBLE_DATE_OPS = (
+    exp.Add,
+    exp.Sub,
+    exp.EQ,
+    exp.NEQ,
+    exp.GT,
+    exp.GTE,
+    exp.LT,
+    exp.LTE,
+    exp.NullSafeEQ,
+    exp.NullSafeNEQ,
+)
+
+
 def coerce_type(node: exp.Expression) -> exp.Expression:
-    if isinstance(node, exp.Binary):
+    if isinstance(node, COERCIBLE_DATE_OPS):
         _coerce_date(node.left, node.right)
     elif isinstance(node, exp.Between):
         _coerce_date(node.this, node.args["low"])
@@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
         *exp.DataType.TEMPORAL_TYPES
     ):
         _replace_cast(node.expression, exp.DataType.Type.DATETIME)
+    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
+        _coerce_timeunit_arg(node.this, node.unit)
+    elif isinstance(node, exp.DateDiff):
+        _coerce_datediff_args(node)

     return node

@@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
     return expression


-def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+def ensure_bools(
+    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
+) -> exp.Expression:
     if isinstance(expression, exp.Connector):
-        _replace_int_predicate(expression.left)
-        _replace_int_predicate(expression.right)
-
-    elif isinstance(expression, (exp.Where, exp.Having)) or (
-        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
-        isinstance(expression, exp.If)
-        and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
+        replace_func(expression.left)
+        replace_func(expression.right)
+    elif isinstance(expression, exp.Not):
+        replace_func(expression.this)
+    # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
+    elif isinstance(expression, exp.If) and not (
+        isinstance(expression.parent, exp.Case) and expression.parent.this
     ):
-        _replace_int_predicate(expression.this)
+        replace_func(expression.this)
+    elif isinstance(expression, (exp.Where, exp.Having)):
+        replace_func(expression.this)

     return expression

@@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:

 def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
     for a, b in itertools.permutations([a, b]):
+        if isinstance(b, exp.Interval):
+            a = _coerce_timeunit_arg(a, b.unit)
         if (
             a.type
             and a.type.this == exp.DataType.Type.DATE
             and b.type
-            and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
+            and b.type.this
+            not in (
+                exp.DataType.Type.DATE,
+                exp.DataType.Type.INTERVAL,
+            )
         ):
             _replace_cast(b, exp.DataType.Type.DATE)


+def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
+    if not arg.type:
+        return arg
+
+    if arg.type.this in exp.DataType.TEXT_TYPES:
+        date_text = arg.name
+        is_iso_date_ = is_iso_date(date_text)
+
+        if is_iso_date_ and is_date_unit(unit):
+            return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE))
+
+        # An ISO date is also an ISO datetime, but not vice versa
+        if is_iso_date_ or is_iso_datetime(date_text):
+            return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+    elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
+        return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+    return arg
+
+
+def _coerce_datediff_args(node: exp.DateDiff) -> None:
+    for e in (node.this, node.expression):
+        if e.type.this not in exp.DataType.TEMPORAL_TYPES:
+            e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
+
+
 def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
     node.replace(exp.cast(node.copy(), to=to))


 # this was originally designed for presto, there is a similar transform for tsql
 # this is different in that it only operates on int types, this is because
 # presto has a boolean type whereas tsql doesn't (people use bits)
 # with y as (select true as x) select x = 0 FROM y -- illegal presto query
 def _replace_int_predicate(expression: exp.Expression) -> None:
-    if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
-        expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
+    if isinstance(expression, exp.Coalesce):
+        for _, child in expression.iter_expressions():
+            _replace_int_predicate(child)
+    elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+        expression.replace(expression.neq(0))
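The new _coerce_timeunit_arg / COERCIBLE_DATE_OPS logic is easiest to see end to end. A small sketch, assuming canonicalize is run after annotate_types as usual; the query and the expected cast are illustrative, not taken from the diff:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

# A DATE plus an HOUR interval: since 'hour' is not a date-only unit,
# the DATE operand is expected to be promoted to DATETIME via a cast.
expr = annotate_types(sqlglot.parse_one("SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' HOUR"))
print(canonicalize(expr).sql())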

sqlglot/optimizer/merge_subqueries.py
@@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
         and not (
             isinstance(from_or_join, exp.Join)
             and inner_select.args.get("where")
-            and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
+            and from_or_join.side in ("FULL", "LEFT", "RIGHT")
         )
         and not (
             isinstance(from_or_join, exp.From)
             and inner_select.args.get("where")
             and any(
-                j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+                j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
             )
         )
         and not _outer_select_joins_on_inner_select_join()

sqlglot/optimizer/normalize_identifiers.py
@@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:


 @t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
     ...


@@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None):
     Returns:
         The transformed expression.
     """
+    dialect = Dialect.get_or_raise(dialect)
+
     if isinstance(expression, str):
         expression = exp.parse_identifier(expression, dialect=dialect)

-    dialect = Dialect.get_or_raise(dialect)
-
     def _normalize(node: E) -> E:
         if not node.meta.get("case_sensitive"):
             exp.replace_children(node, _normalize)
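With the tightened overload and the earlier Dialect.get_or_raise call, string inputs are parsed with the target dialect before normalization. A usage sketch; the identifier value and the expected casing are assumptions:

from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# Snowflake normalizes unquoted identifiers to upper case, so the parsed
# string should come back as an exp.Identifier named MYCOL.
ident = normalize_identifiers("MyCol", dialect="snowflake")
print(type(ident).__name__, ident.name)  # expected: Identifier MYCOL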

sqlglot/optimizer/optimizer.py
@@ -42,8 +42,8 @@ RULES = (
 def optimize(
     expression: str | exp.Expression,
     schema: t.Optional[dict | Schema] = None,
-    db: t.Optional[str] = None,
-    catalog: t.Optional[str] = None,
+    db: t.Optional[str | exp.Identifier] = None,
+    catalog: t.Optional[str | exp.Identifier] = None,
     dialect: DialectType = None,
     rules: t.Sequence[t.Callable] = RULES,
     **kwargs,
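Downstream, optimize simply forwards db and catalog to the qualification rules, so both plain strings and pre-built identifiers are accepted. A minimal sketch, assuming a one-table schema; the names are illustrative:

import sqlglot
from sqlglot.optimizer import optimize

# db may now be a string or an exp.Identifier; either way it ends up on the table.
optimized = optimize(
    sqlglot.parse_one("SELECT a FROM t"),
    schema={"db": {"t": {"a": "INT"}}},
    db="db",
)
print(optimized.sql())  # the table reference is expected to be qualified as db.t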

sqlglot/optimizer/qualify_columns.py
@@ -8,7 +8,7 @@ from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import seq_get
-from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema

@@ -58,7 +58,7 @@ def qualify_columns(

         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
-            _qualify_outputs(scope)
+            qualify_outputs(scope)

         _expand_group_by(scope)
         _expand_order_by(scope, resolver)
@@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
     ordereds = order.expressions
     for ordered, new_expression in zip(
         ordereds,
-        _expand_positional_references(scope, (o.this for o in ordereds)),
+        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
     ):
         for agg in ordered.find_all(exp.AggFunc):
             for col in agg.find_all(exp.Column):
@@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
     )


-def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
-    new_nodes = []
+def _expand_positional_references(
+    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
+) -> t.List[exp.Expression]:
+    new_nodes: t.List[exp.Expression] = []
     for node in expressions:
         if node.is_int:
-            select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+            select = _select_by_pos(scope, t.cast(exp.Literal, node))

-            if isinstance(select, exp.Literal):
-                new_nodes.append(node)
+            if alias:
+                new_nodes.append(exp.column(select.args["alias"].copy()))
             else:
-                new_nodes.append(select.copy())
-                scope.clear_cache()
+                select = select.this
+
+                if isinstance(select, exp.Literal):
+                    new_nodes.append(node)
+                else:
+                    new_nodes.append(select.copy())
         else:
             new_nodes.append(node)
@@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
         if column_table:
             column.set("table", column_table)
         elif column_table not in scope.sources and (
-            not scope.parent or column_table not in scope.parent.sources
+            not scope.parent
+            or column_table not in scope.parent.sources
+            or not scope.is_correlated_subquery
         ):
             # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
             # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -381,15 +389,18 @@ def _expand_stars(
                 columns = [name for name in columns if name.upper() not in pseudocolumns]

             if columns and "*" not in columns:
+                table_id = id(table)
+                columns_to_exclude = except_columns.get(table_id) or set()
+
                 if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                     implicit_columns = [col for col in columns if col not in pivot_columns]
                     new_selections.extend(
                         exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                         for name in implicit_columns + pivot_output_columns
+                        if name not in columns_to_exclude
                     )
                     continue

-                table_id = id(table)
                 for name in columns:
                     if name in using_column_tables and table in using_column_tables[name]:
                         if name in coalesced_columns:
@@ -406,7 +417,7 @@ def _expand_stars(
                                 copy=False,
                             )
                         )
-                    elif name not in except_columns.get(table_id, set()):
+                    elif name not in columns_to_exclude:
                         alias_ = replace_columns.get(table_id, {}).get(name, name)
                         column = exp.column(name, table=table)
                         new_selections.append(
@@ -448,10 +459,16 @@ def _add_replace_columns(
     replace_columns[id(table)] = columns


-def _qualify_outputs(scope: Scope) -> None:
+def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
     """Ensure all output columns are aliased"""
-    new_selections = []
+    if isinstance(scope_or_expression, exp.Expression):
+        scope = build_scope(scope_or_expression)
+        if not isinstance(scope, Scope):
+            return
+    else:
+        scope = scope_or_expression
+
+    new_selections = []
     for i, (selection, aliased_column) in enumerate(
         itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
     ):
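Because qualify_outputs now accepts either a Scope or a bare expression, it can be called directly on a parsed statement. A small sketch; the query and alias names are illustrative:

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_outputs

expr = sqlglot.parse_one("SELECT x + 1, y FROM t")
qualify_outputs(expr)  # builds a Scope internally when given an expression
print(expr.sql())  # unaliased projections are expected to gain _col_N aliases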

sqlglot/optimizer/qualify_tables.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 import itertools
 import typing as t

 from sqlglot import alias, exp
 from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
 from sqlglot.helper import csv_reader, name_sequence
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import Schema
@@ -10,9 +13,10 @@ from sqlglot.schema import Schema

 def qualify_tables(
     expression: E,
-    db: t.Optional[str] = None,
-    catalog: t.Optional[str] = None,
+    db: t.Optional[str | exp.Identifier] = None,
+    catalog: t.Optional[str | exp.Identifier] = None,
     schema: t.Optional[Schema] = None,
+    dialect: DialectType = None,
 ) -> E:
     """
     Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -33,11 +37,14 @@ def qualify_tables(
         db: Database name
         catalog: Catalog name
         schema: A schema to populate
+        dialect: The dialect to parse catalog and schema into.

     Returns:
         The qualified expression.
     """
     next_alias_name = name_sequence("_q_")
+    db = exp.parse_identifier(db, dialect=dialect) if db else None
+    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

     for scope in traverse_scope(expression):
         for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -61,9 +68,9 @@ def qualify_tables(
             if isinstance(source, exp.Table):
                 if isinstance(source.this, exp.Identifier):
                     if not source.args.get("db"):
-                        source.set("db", exp.to_identifier(db))
+                        source.set("db", db)
                     if not source.args.get("catalog") and source.args.get("db"):
-                        source.set("catalog", exp.to_identifier(catalog))
+                        source.set("catalog", catalog)

                 if not source.alias:
                     # Mutates the source by attaching an alias to it
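The dialect argument means quoted db/catalog strings are parsed correctly instead of being wrapped verbatim by exp.to_identifier. A sketch with assumed names:

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

expr = sqlglot.parse_one("SELECT 1 FROM tbl")
qualified = qualify_tables(expr, db='"Sales"', catalog="main", dialect="duckdb")
print(qualified.sql())  # expected: SELECT 1 FROM main."Sales".tbl AS tbl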

sqlglot/optimizer/simplify.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import itertools
 import logging
 import typing as t
@@ -507,6 +507,9 @@ def simplify_literals(expression, root=True):
                 return exp.Literal.number(value[1:])
             return exp.Literal.number(f"-{value}")

+    if type(expression) in INVERSE_DATE_OPS:
+        return _simplify_binary(expression, expression.this, expression.interval()) or expression
+
     return expression


@@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b):
         return exp.null()

     if a.is_number and b.is_number:
-        a = int(a.name) if a.is_int else Decimal(a.name)
-        b = int(b.name) if b.is_int else Decimal(b.name)
+        num_a = int(a.name) if a.is_int else Decimal(a.name)
+        num_b = int(b.name) if b.is_int else Decimal(b.name)

         if isinstance(expression, exp.Add):
-            return exp.Literal.number(a + b)
-        if isinstance(expression, exp.Sub):
-            return exp.Literal.number(a - b)
+            return exp.Literal.number(num_a + num_b)
         if isinstance(expression, exp.Mul):
-            return exp.Literal.number(a * b)
+            return exp.Literal.number(num_a * num_b)
+
+        # We only simplify Sub, Div if a and b have the same parent because they're not associative
+        if isinstance(expression, exp.Sub):
+            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
         if isinstance(expression, exp.Div):
             # engines have differing int div behavior so intdiv is not safe
-            if isinstance(a, int) and isinstance(b, int):
+            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                 return None
-            return exp.Literal.number(a / b)
+            return exp.Literal.number(num_a / num_b)

-        boolean = eval_boolean(expression, a, b)
+        boolean = eval_boolean(expression, num_a, num_b)

         if boolean:
             return boolean
@@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b):
     elif _is_date_literal(a) and isinstance(b, exp.Interval):
         a, b = extract_date(a), extract_interval(b)
         if a and b:
-            if isinstance(expression, exp.Add):
+            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                 return date_literal(a + b)
-            if isinstance(expression, exp.Sub):
+            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                 return date_literal(a - b)
+    elif isinstance(a, exp.Interval) and _is_date_literal(b):
+        a, b = extract_interval(a), extract_date(b)
+        # you cannot subtract a date from an interval
+        if a and b and isinstance(expression, exp.Add):
+            return date_literal(a + b)
     elif _is_date_literal(a) and _is_date_literal(b):
         if isinstance(expression, exp.Predicate):
             a, b = extract_date(a), extract_date(b)
             boolean = eval_boolean(expression, a, b)
             if boolean:
                 return boolean

     return None
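The widened isinstance checks let date arithmetic fold whether it is written as an operator or as a DATE_ADD-style function routed through INVERSE_DATE_OPS. A folding sketch, assuming the optional python-dateutil dependency is installed (otherwise extract_interval returns None and nothing folds):

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY")
print(simplify(expr).sql())  # expected to fold to a 2023-01-02 date literal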
@@ -590,6 +601,11 @@ def simplify_parens(expression):
     return expression


+NONNULL_CONSTANTS = (
+    exp.Literal,
+    exp.Boolean,
+)
+
 CONSTANTS = (
     exp.Literal,
     exp.Boolean,
@@ -597,11 +613,19 @@ CONSTANTS = (
 )


+def _is_nonnull_constant(expression: exp.Expression) -> bool:
+    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+
+
+def _is_constant(expression: exp.Expression) -> bool:
+    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+
+
 def simplify_coalesce(expression):
     # COALESCE(x) -> x
     if (
         isinstance(expression, exp.Coalesce)
-        and not expression.expressions
+        and (not expression.expressions or _is_nonnull_constant(expression.this))
         # COALESCE is also used as a Spark partitioning hint
         and not isinstance(expression.parent, exp.Hint)
     ):
@@ -621,12 +645,12 @@ def simplify_coalesce(expression):

     # This transformation is valid for non-constants,
     # but it really only does anything if they are both constants.
-    if not isinstance(other, CONSTANTS):
+    if not _is_constant(other):
         return expression

     # Find the first constant arg
     for arg_index, arg in enumerate(coalesce.expressions):
-        if isinstance(arg, CONSTANTS):
+        if _is_constant(arg):
             break
     else:
         return expression
@@ -656,7 +680,6 @@ def simplify_coalesce(expression):


 CONCATS = (exp.Concat, exp.DPipe)
-SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


 def simplify_concat(expression):
@@ -672,10 +695,15 @@ def simplify_concat(expression):
         sep_expr, *expressions = expression.expressions
         sep = sep_expr.name
         concat_type = exp.ConcatWs
+        args = {}
     else:
         expressions = expression.expressions
         sep = ""
-        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+        concat_type = exp.Concat
+        args = {
+            "safe": expression.args.get("safe"),
+            "coalesce": expression.args.get("coalesce"),
+        }

     new_args = []
     for is_string_group, group in itertools.groupby(
@@ -692,7 +720,7 @@ def simplify_concat(expression):
     if concat_type is exp.ConcatWs:
         new_args = [sep_expr] + new_args

-    return concat_type(expressions=new_args)
+    return concat_type(expressions=new_args, **args)


 def simplify_conditionals(expression):
@@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
 def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
     if isinstance(cast, exp.Cast):
         to = cast.to
-    elif isinstance(cast, exp.TsOrDsToDate):
+    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
         to = exp.DataType.build(exp.DataType.Type.DATE)
     else:
         return None
@@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool:


 def extract_interval(expression):
-    n = int(expression.name)
-    unit = expression.text("unit").lower()
-
     try:
+        n = int(expression.name)
+        unit = expression.text("unit").lower()
         return interval(unit, n)
-    except (UnsupportedUnit, ModuleNotFoundError):
+    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
         return None


@@ -1099,8 +1126,6 @@ GEN_MAP = {
     exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
-    exp.Div: lambda e: _binary(e, "/"),
     exp.Dot: lambda e: _binary(e, "."),
     exp.DPipe: lambda e: _binary(e, "||"),
-    exp.SafeDPipe: lambda e: _binary(e, "||"),
     exp.EQ: lambda e: _binary(e, "="),
     exp.GT: lambda e: _binary(e, ">"),
     exp.GTE: lambda e: _binary(e, ">="),
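The simplify_concat rewrite keeps folding adjacent string literals while threading the new safe/coalesce flags through to the rebuilt node. A sketch; the query is illustrative:

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("SELECT CONCAT('foo', 'bar', x)")
print(simplify(expr).sql())  # expected: SELECT CONCAT('foobar', x)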