Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
ebba7c6a18
commit
d26905e4af
187 changed files with 86502 additions and 71397 deletions
|
@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Exp,
|
||||
exp.Ln,
|
||||
exp.Log,
|
||||
exp.Log2,
|
||||
exp.Log10,
|
||||
exp.Pow,
|
||||
exp.Quantile,
|
||||
exp.Round,
|
||||
|
@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.Dot: lambda self, e: self._annotate_dot(e),
|
||||
exp.Explode: lambda self, e: self._annotate_explode(e),
|
||||
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
|
||||
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
|
||||
e, exp.DataType.build("ARRAY<DATE>")
|
||||
),
|
||||
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
|
||||
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
|
||||
exp.Literal: lambda self, e: self._annotate_literal(e),
|
||||
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
exp.Map: lambda self, e: self._annotate_map(e),
|
||||
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
|
||||
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
|
||||
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
|
||||
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
|
||||
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
|
||||
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
|
||||
exp.Struct: lambda self, e: self._annotate_struct(e),
|
||||
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
|
||||
exp.Timestamp: lambda self, e: self._annotate_with_type(
|
||||
e,
|
||||
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
|
||||
),
|
||||
exp.ToMap: lambda self, e: self._annotate_to_map(e),
|
||||
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
|
||||
exp.Unnest: lambda self, e: self._annotate_unnest(e),
|
||||
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
|
||||
exp.VarMap: lambda self, e: self._annotate_map(e),
|
||||
}
|
||||
|
||||
NESTED_TYPES = {
|
||||
|
@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
|
@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
)
|
||||
|
||||
def _annotate_args(self, expression: E) -> E:
|
||||
for _, value in expression.iter_expressions():
|
||||
for value in expression.iter_expressions():
|
||||
self._maybe_annotate(value)
|
||||
|
||||
return expression
|
||||
|
@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
|
||||
if type1_value in self.NESTED_TYPES:
|
||||
return type1
|
||||
if type2_value in self.NESTED_TYPES:
|
||||
return type2
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
||||
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
|
||||
|
||||
# Note: the following "no_type_check" decorators were added because mypy was yelling due
|
||||
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
|
||||
# This is a known mypy issue: https://github.com/python/mypy/issues/3004
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_binary(self, expression: B) -> B:
|
||||
self._annotate_args(expression)
|
||||
|
||||
left, right = expression.left, expression.right
|
||||
left_type, right_type = left.type.this, right.type.this
|
||||
left_type, right_type = left.type.this, right.type.this # type: ignore
|
||||
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
|
@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_unary(self, expression: E) -> E:
|
||||
self._annotate_args(expression)
|
||||
|
||||
|
@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
|
||||
if expression.is_string:
|
||||
self._set_type(expression, exp.DataType.Type.VARCHAR)
|
||||
|
@ -484,25 +476,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
return expression
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
|
||||
self._set_type(expression, target_type)
|
||||
return self._annotate_args(expression)
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_struct_value(
|
||||
self, expression: exp.Expression
|
||||
) -> t.Optional[exp.DataType] | exp.ColumnDef:
|
||||
alias = expression.args.get("alias")
|
||||
if alias:
|
||||
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
|
||||
|
||||
# Case: key = value or key := value
|
||||
if expression.expression:
|
||||
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
|
||||
|
||||
return expression.type
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_by_args(
|
||||
self,
|
||||
|
@ -510,7 +487,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
*args: str,
|
||||
promote: bool = False,
|
||||
array: bool = False,
|
||||
struct: bool = False,
|
||||
) -> E:
|
||||
self._annotate_args(expression)
|
||||
|
||||
|
@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
),
|
||||
)
|
||||
|
||||
if struct:
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType(
|
||||
this=exp.DataType.Type.STRUCT,
|
||||
expressions=[self._annotate_struct_value(expr) for expr in expressions],
|
||||
nested=True,
|
||||
),
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
def _annotate_timeunit(
|
||||
|
@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._set_type(expression, exp.DataType.Type.BIGINT)
|
||||
else:
|
||||
self._set_type(expression, self._maybe_coerce(left_type, right_type))
|
||||
if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
|
||||
self._set_type(
|
||||
expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
child = seq_get(expression.expressions, 0)
|
||||
self._set_type(expression, child and seq_get(child.type.expressions, 0))
|
||||
return expression
|
||||
|
||||
def _annotate_struct_value(
|
||||
self, expression: exp.Expression
|
||||
) -> t.Optional[exp.DataType] | exp.ColumnDef:
|
||||
alias = expression.args.get("alias")
|
||||
if alias:
|
||||
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
|
||||
|
||||
# Case: key = value or key := value
|
||||
if expression.expression:
|
||||
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
|
||||
|
||||
return expression.type
|
||||
|
||||
def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
|
||||
self._annotate_args(expression)
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType(
|
||||
this=exp.DataType.Type.STRUCT,
|
||||
expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
|
||||
nested=True,
|
||||
),
|
||||
)
|
||||
return expression
|
||||
|
||||
@t.overload
|
||||
def _annotate_map(self, expression: exp.Map) -> exp.Map: ...
|
||||
|
||||
@t.overload
|
||||
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...
|
||||
|
||||
def _annotate_map(self, expression):
|
||||
self._annotate_args(expression)
|
||||
|
||||
keys = expression.args.get("keys")
|
||||
values = expression.args.get("values")
|
||||
|
||||
map_type = exp.DataType(this=exp.DataType.Type.MAP)
|
||||
if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
|
||||
key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
|
||||
value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN
|
||||
|
||||
if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
|
||||
map_type.set("expressions", [key_type, value_type])
|
||||
map_type.set("nested", True)
|
||||
|
||||
self._set_type(expression, map_type)
|
||||
return expression
|
||||
|
||||
def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
|
||||
self._annotate_args(expression)
|
||||
|
||||
map_type = exp.DataType(this=exp.DataType.Type.MAP)
|
||||
arg = expression.this
|
||||
if arg.is_type(exp.DataType.Type.STRUCT):
|
||||
for coldef in arg.type.expressions:
|
||||
kind = coldef.kind
|
||||
if kind != exp.DataType.Type.UNKNOWN:
|
||||
map_type.set("expressions", [exp.DataType.build("varchar"), kind])
|
||||
map_type.set("nested", True)
|
||||
break
|
||||
|
||||
self._set_type(expression, map_type)
|
||||
return expression
|
||||
|
|
|
@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
|
|||
Args:
|
||||
expression: The expression to canonicalize.
|
||||
"""
|
||||
exp.replace_children(expression, canonicalize)
|
||||
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = replace_date_funcs(expression)
|
||||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bools(expression, _replace_int_predicate)
|
||||
expression = remove_ascending_order(expression)
|
||||
def _canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = replace_date_funcs(expression)
|
||||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bools(expression, _replace_int_predicate)
|
||||
expression = remove_ascending_order(expression)
|
||||
return expression
|
||||
|
||||
return expression
|
||||
return exp.replace_tree(expression, _canonicalize)
|
||||
|
||||
|
||||
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
|
||||
|
@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
|
|||
|
||||
|
||||
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
||||
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
|
||||
if (
|
||||
isinstance(node, (exp.Date, exp.TsOrDsToDate))
|
||||
and not node.expressions
|
||||
and not node.args.get("zone")
|
||||
):
|
||||
return exp.cast(node.this, to=exp.DataType.Type.DATE)
|
||||
if isinstance(node, exp.Timestamp) and not node.expression:
|
||||
if not node.type:
|
||||
|
@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
|
|||
a = _coerce_timeunit_arg(a, b.unit)
|
||||
if (
|
||||
a.type
|
||||
and a.type.this == exp.DataType.Type.DATE
|
||||
and a.type.this in exp.DataType.TEMPORAL_TYPES
|
||||
and b.type
|
||||
and b.type.this
|
||||
not in (
|
||||
exp.DataType.Type.DATE,
|
||||
exp.DataType.Type.INTERVAL,
|
||||
)
|
||||
and b.type.this in exp.DataType.TEXT_TYPES
|
||||
):
|
||||
_replace_cast(b, exp.DataType.Type.DATE)
|
||||
_replace_cast(b, exp.DataType.Type.DATETIME)
|
||||
|
||||
|
||||
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
|
||||
|
@ -169,7 +170,7 @@ def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
|
|||
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
|
||||
def _replace_int_predicate(expression: exp.Expression) -> None:
|
||||
if isinstance(expression, exp.Coalesce):
|
||||
for _, child in expression.iter_expressions():
|
||||
for child in expression.iter_expressions():
|
||||
_replace_int_predicate(child)
|
||||
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
|
||||
expression.replace(expression.neq(0))
|
||||
|
|
|
@ -32,7 +32,7 @@ def eliminate_ctes(expression):
|
|||
cte_node.pop()
|
||||
|
||||
# Pop the entire WITH clause if this is the last CTE
|
||||
if len(with_node.expressions) <= 0:
|
||||
if with_node and len(with_node.expressions) <= 0:
|
||||
with_node.pop()
|
||||
|
||||
# Decrement the ref count for all sources this CTE selects from
|
||||
|
|
|
@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
and not _outer_select_joins_on_inner_select_join()
|
||||
and not _is_a_window_expression_in_unmergable_operation()
|
||||
and not _is_recursive()
|
||||
and not (inner_select.args.get("order") and outer_scope.is_union)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
Returns:
|
||||
sqlglot.Expression: normalized expression
|
||||
"""
|
||||
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
|
||||
for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
|
||||
if isinstance(node, exp.Connector):
|
||||
if normalized(node, dnf=dnf):
|
||||
continue
|
||||
|
|
|
@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None):
|
|||
if isinstance(expression, str):
|
||||
expression = exp.parse_identifier(expression, dialect=dialect)
|
||||
|
||||
def _normalize(node: E) -> E:
|
||||
for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
|
||||
if not node.meta.get("case_sensitive"):
|
||||
exp.replace_children(node, _normalize)
|
||||
node = dialect.normalize_identifier(node)
|
||||
return node
|
||||
dialect.normalize_identifier(node)
|
||||
|
||||
return _normalize(expression)
|
||||
return expression
|
||||
|
|
|
@ -82,13 +82,13 @@ def optimize(
|
|||
**kwargs,
|
||||
}
|
||||
|
||||
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
for rule in rules:
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
rule_kwargs = {
|
||||
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
|
||||
}
|
||||
expression = rule(expression, **rule_kwargs)
|
||||
optimized = rule(optimized, **rule_kwargs)
|
||||
|
||||
return t.cast(exp.Expression, expression)
|
||||
return optimized
|
||||
|
|
|
@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
|
|||
pushdown_dnf(predicates, sources, scope_ref_count)
|
||||
|
||||
|
||||
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
|
||||
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
|
||||
"""
|
||||
If the predicates are in CNF like form, we can simply replace each block in the parent.
|
||||
"""
|
||||
join_index = join_index or {}
|
||||
for predicate in predicates:
|
||||
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
|
||||
for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
|
||||
if isinstance(node, exp.Join):
|
||||
name = node.alias_or_name
|
||||
predicate_tables = exp.column_table_names(predicate, name)
|
||||
|
@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
|
|||
node.where(inner_predicate, copy=False)
|
||||
|
||||
|
||||
def pushdown_dnf(predicates, scope, scope_ref_count):
|
||||
def pushdown_dnf(predicates, sources, scope_ref_count):
|
||||
"""
|
||||
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
|
||||
Additionally, we can't remove predicates from their original form.
|
||||
|
@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
|
|||
# pushdown all predicates to their respective nodes
|
||||
for table in sorted(pushdown_tables):
|
||||
for predicate in predicates:
|
||||
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
|
||||
nodes = nodes_for_predicate(predicate, sources, scope_ref_count)
|
||||
|
||||
if table not in nodes:
|
||||
continue
|
||||
|
|
|
@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
if scope.expression.args.get("by_name"):
|
||||
referenced_columns[right] = referenced_columns[left]
|
||||
else:
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections
|
||||
or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if remove_unused_selections:
|
||||
|
|
|
@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if not node:
|
||||
return
|
||||
|
||||
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
|
||||
for column in walk_in_scope(node, prune=lambda node: node.is_star):
|
||||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
|
||||
|
@ -306,7 +306,7 @@ def _expand_positional_references(
|
|||
else:
|
||||
select = select.this
|
||||
|
||||
if isinstance(select, exp.Literal):
|
||||
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
|
||||
new_nodes.append(node)
|
||||
else:
|
||||
new_nodes.append(select.copy())
|
||||
|
@ -425,7 +425,7 @@ def _expand_stars(
|
|||
raise OptimizeError(f"Unknown table: {table}")
|
||||
|
||||
columns = resolver.get_source_columns(table, only_visible=True)
|
||||
columns = columns or scope.outer_column_list
|
||||
columns = columns or scope.outer_columns
|
||||
|
||||
if pseudocolumns:
|
||||
columns = [name for name in columns if name.upper() not in pseudocolumns]
|
||||
|
@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
|
|||
|
||||
new_selections = []
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
|
||||
):
|
||||
if selection is None:
|
||||
break
|
||||
|
@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
|
|||
"""Makes sure all identifiers that need to be quoted are quoted."""
|
||||
return expression.transform(
|
||||
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
|
||||
|
|
|
@ -56,7 +56,7 @@ def qualify_tables(
|
|||
table.set("catalog", catalog)
|
||||
|
||||
if not isinstance(expression, exp.Query):
|
||||
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
|
||||
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
|
||||
if isinstance(node, exp.Table):
|
||||
_qualify(node)
|
||||
|
||||
|
@ -118,11 +118,11 @@ def qualify_tables(
|
|||
for i, e in enumerate(udtf.expressions[0].expressions):
|
||||
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
|
||||
else:
|
||||
for node, parent, _ in scope.walk():
|
||||
for node in scope.walk():
|
||||
if (
|
||||
isinstance(node, exp.Table)
|
||||
and not node.alias
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and isinstance(node.parent, (exp.From, exp.Join))
|
||||
):
|
||||
# Mutates the table by attaching an alias to it
|
||||
alias(node, node.name, copy=False, table=True)
|
||||
|
|
|
@ -8,7 +8,7 @@ from enum import Enum, auto
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import ensure_collection, find_new_name
|
||||
from sqlglot.helper import ensure_collection, find_new_name, seq_get
|
||||
|
||||
logger = logging.getLogger("sqlglot")
|
||||
|
||||
|
@ -38,11 +38,11 @@ class Scope:
|
|||
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
|
||||
The LATERAL VIEW EXPLODE gets x as a source.
|
||||
cte_sources (dict[str, Scope]): Sources from CTES
|
||||
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list of it's alias of this scope, this is that list of columns.
|
||||
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
|
||||
defines a column list for the alias of this scope, this is that list of columns.
|
||||
For example:
|
||||
SELECT * FROM (SELECT ...) AS y(col1, col2)
|
||||
The inner query would have `["col1", "col2"]` for its `outer_column_list`
|
||||
The inner query would have `["col1", "col2"]` for its `outer_columns`
|
||||
parent (Scope): Parent scope
|
||||
scope_type (ScopeType): Type of this scope, relative to it's parent
|
||||
subquery_scopes (list[Scope]): List of all child scopes for subqueries
|
||||
|
@ -58,7 +58,7 @@ class Scope:
|
|||
self,
|
||||
expression,
|
||||
sources=None,
|
||||
outer_column_list=None,
|
||||
outer_columns=None,
|
||||
parent=None,
|
||||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
|
@ -70,7 +70,7 @@ class Scope:
|
|||
self.cte_sources = cte_sources or {}
|
||||
self.sources.update(self.lateral_sources)
|
||||
self.sources.update(self.cte_sources)
|
||||
self.outer_column_list = outer_column_list or []
|
||||
self.outer_columns = outer_columns or []
|
||||
self.parent = parent
|
||||
self.scope_type = scope_type
|
||||
self.subquery_scopes = []
|
||||
|
@ -119,10 +119,11 @@ class Scope:
|
|||
self._raw_columns = []
|
||||
self._join_hints = []
|
||||
|
||||
for node, parent, _ in self.walk(bfs=False):
|
||||
for node in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
continue
|
||||
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
|
||||
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
|
||||
self._tables.append(node)
|
||||
|
@ -132,10 +133,8 @@ class Scope:
|
|||
self._udtfs.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
elif (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and _is_derived_table(node)
|
||||
elif _is_derived_table(node) and isinstance(
|
||||
node.parent, (exp.From, exp.Join, exp.Subquery)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.UNWRAPPED_QUERIES):
|
||||
|
@ -438,11 +437,21 @@ class Scope:
|
|||
Yields:
|
||||
Scope: scope instances in depth-first-search post-order
|
||||
"""
|
||||
for child_scope in itertools.chain(
|
||||
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
|
||||
):
|
||||
yield from child_scope.traverse()
|
||||
yield self
|
||||
stack = [self]
|
||||
result = []
|
||||
while stack:
|
||||
scope = stack.pop()
|
||||
result.append(scope)
|
||||
stack.extend(
|
||||
itertools.chain(
|
||||
scope.cte_scopes,
|
||||
scope.union_scopes,
|
||||
scope.table_scopes,
|
||||
scope.subquery_scopes,
|
||||
)
|
||||
)
|
||||
|
||||
yield from reversed(result)
|
||||
|
||||
def ref_count(self):
|
||||
"""
|
||||
|
@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
|
|||
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
|
||||
|
||||
Args:
|
||||
expression (exp.Expression): expression to traverse
|
||||
expression: Expression to traverse
|
||||
|
||||
Returns:
|
||||
list[Scope]: scope instances
|
||||
A list of the created scope instances
|
||||
"""
|
||||
if isinstance(expression, exp.Query) or (
|
||||
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
|
||||
):
|
||||
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
|
||||
# We ignore the DDL expression and build a scope for its query instead
|
||||
ddl_with = expression.args.get("with")
|
||||
expression = expression.expression
|
||||
|
||||
# If the DDL has CTEs attached, we need to add them to the query, or
|
||||
# prepend them if the query itself already has CTEs attached to it
|
||||
if ddl_with:
|
||||
ddl_with.pop()
|
||||
query_ctes = expression.ctes
|
||||
if not query_ctes:
|
||||
expression.set("with", ddl_with)
|
||||
else:
|
||||
expression.args["with"].set("recursive", ddl_with.recursive)
|
||||
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
|
||||
|
||||
if isinstance(expression, exp.Query):
|
||||
return list(_traverse_scope(Scope(expression)))
|
||||
|
||||
return []
|
||||
|
@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
|
|||
Build a scope tree.
|
||||
|
||||
Args:
|
||||
expression (exp.Expression): expression to build the scope tree for
|
||||
expression: Expression to build the scope tree for.
|
||||
|
||||
Returns:
|
||||
Scope: root scope
|
||||
The root scope
|
||||
"""
|
||||
scopes = traverse_scope(expression)
|
||||
if scopes:
|
||||
return scopes[-1]
|
||||
return None
|
||||
return seq_get(traverse_scope(expression), -1)
|
||||
|
||||
|
||||
def _traverse_scope(scope):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_union(scope)
|
||||
return
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
if scope.is_root:
|
||||
yield from _traverse_select(scope)
|
||||
|
@ -523,8 +546,6 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
yield from _traverse_udtfs(scope)
|
||||
elif isinstance(scope.expression, exp.DDL):
|
||||
yield from _traverse_ddl(scope)
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
|
||||
|
@ -541,30 +562,38 @@ def _traverse_select(scope):
|
|||
|
||||
|
||||
def _traverse_union(scope):
|
||||
yield from _traverse_ctes(scope)
|
||||
prev_scope = None
|
||||
union_scope_stack = [scope]
|
||||
expression_stack = [scope.expression.right, scope.expression.left]
|
||||
|
||||
# The last scope to be yield should be the top most scope
|
||||
left = None
|
||||
for left in _traverse_scope(
|
||||
scope.branch(
|
||||
scope.expression.left,
|
||||
outer_column_list=scope.outer_column_list,
|
||||
while expression_stack:
|
||||
expression = expression_stack.pop()
|
||||
union_scope = union_scope_stack[-1]
|
||||
|
||||
new_scope = union_scope.branch(
|
||||
expression,
|
||||
outer_columns=union_scope.outer_columns,
|
||||
scope_type=ScopeType.UNION,
|
||||
)
|
||||
):
|
||||
yield left
|
||||
|
||||
right = None
|
||||
for right in _traverse_scope(
|
||||
scope.branch(
|
||||
scope.expression.right,
|
||||
outer_column_list=scope.outer_column_list,
|
||||
scope_type=ScopeType.UNION,
|
||||
)
|
||||
):
|
||||
yield right
|
||||
if isinstance(expression, exp.Union):
|
||||
yield from _traverse_ctes(new_scope)
|
||||
|
||||
scope.union_scopes = [left, right]
|
||||
union_scope_stack.append(new_scope)
|
||||
expression_stack.extend([expression.right, expression.left])
|
||||
continue
|
||||
|
||||
for scope in _traverse_scope(new_scope):
|
||||
yield scope
|
||||
|
||||
if prev_scope:
|
||||
union_scope_stack.pop()
|
||||
union_scope.union_scopes = [prev_scope, scope]
|
||||
prev_scope = union_scope
|
||||
|
||||
yield union_scope
|
||||
else:
|
||||
prev_scope = scope
|
||||
|
||||
|
||||
def _traverse_ctes(scope):
|
||||
|
@ -588,7 +617,7 @@ def _traverse_ctes(scope):
|
|||
scope.branch(
|
||||
cte.this,
|
||||
cte_sources=sources,
|
||||
outer_column_list=cte.alias_column_names,
|
||||
outer_columns=cte.alias_column_names,
|
||||
scope_type=ScopeType.CTE,
|
||||
)
|
||||
):
|
||||
|
@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
|
|||
as it doesn't introduce a new scope. If an alias is present, it shadows all names
|
||||
under the Subquery, so that's one exception to this rule.
|
||||
"""
|
||||
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
|
||||
return isinstance(expression, exp.Subquery) and bool(
|
||||
expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
|
||||
)
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
|
@ -681,7 +712,7 @@ def _traverse_tables(scope):
|
|||
scope.branch(
|
||||
expression,
|
||||
lateral_sources=lateral_sources,
|
||||
outer_column_list=expression.alias_column_names,
|
||||
outer_columns=expression.alias_column_names,
|
||||
scope_type=scope_type,
|
||||
)
|
||||
):
|
||||
|
@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
|
|||
|
||||
sources = {}
|
||||
for expression in expressions:
|
||||
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
|
||||
if _is_derived_table(expression):
|
||||
top = None
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
expression,
|
||||
scope_type=ScopeType.DERIVED_TABLE,
|
||||
outer_column_list=expression.alias_column_names,
|
||||
outer_columns=expression.alias_column_names,
|
||||
)
|
||||
):
|
||||
yield child_scope
|
||||
|
@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
|
|||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _traverse_ddl(scope):
|
||||
yield from _traverse_ctes(scope)
|
||||
|
||||
query_scope = scope.branch(
|
||||
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
|
||||
)
|
||||
query_scope._collect()
|
||||
query_scope._ctes = scope.ctes + query_scope._ctes
|
||||
|
||||
yield from _traverse_scope(query_scope)
|
||||
|
||||
|
||||
def walk_in_scope(expression, bfs=True, prune=None):
|
||||
"""
|
||||
Returns a generator object which visits all nodes in the syntrax tree, stopping at
|
||||
|
@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
|
|||
# Whenever we set it to True, we exclude a subtree from traversal.
|
||||
crossed_scope_boundary = False
|
||||
|
||||
for node, parent, key in expression.walk(
|
||||
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
|
||||
for node in expression.walk(
|
||||
bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
|
||||
):
|
||||
crossed_scope_boundary = False
|
||||
|
||||
yield node, parent, key
|
||||
yield node
|
||||
|
||||
if node is expression:
|
||||
continue
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and _is_derived_table(node)
|
||||
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
|
||||
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.UNWRAPPED_QUERIES)
|
||||
):
|
||||
crossed_scope_boundary = True
|
||||
|
@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
|
|||
Yields:
|
||||
exp.Expression: nodes
|
||||
"""
|
||||
for expression, *_ in walk_in_scope(expression, bfs=bfs):
|
||||
for expression in walk_in_scope(expression, bfs=bfs):
|
||||
if isinstance(expression, tuple(ensure_collection(expression_types))):
|
||||
yield expression
|
||||
|
||||
|
|
|
@ -9,19 +9,25 @@ from decimal import Decimal
|
|||
|
||||
import sqlglot
|
||||
from sqlglot import Dialect, exp
|
||||
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
|
||||
from sqlglot.helper import first, merge_ranges, while_changing
|
||||
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
DateTruncBinaryTransform = t.Callable[
|
||||
[exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
|
||||
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
|
||||
]
|
||||
|
||||
# Final means that an expression should not be simplified
|
||||
FINAL = "final"
|
||||
|
||||
# Value ranges for byte-sized signed/unsigned integers
|
||||
TINYINT_MIN = -128
|
||||
TINYINT_MAX = 127
|
||||
UTINYINT_MIN = 0
|
||||
UTINYINT_MAX = 255
|
||||
|
||||
|
||||
class UnsupportedUnit(Exception):
|
||||
pass
|
||||
|
@ -63,14 +69,14 @@ def simplify(
|
|||
group.meta[FINAL] = True
|
||||
|
||||
for e in expression.selects:
|
||||
for node, *_ in e.walk():
|
||||
for node in e.walk():
|
||||
if node in groups:
|
||||
e.meta[FINAL] = True
|
||||
break
|
||||
|
||||
having = expression.args.get("having")
|
||||
if having:
|
||||
for node, *_ in having.walk():
|
||||
for node in having.walk():
|
||||
if node in groups:
|
||||
having.meta[FINAL] = True
|
||||
break
|
||||
|
@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False):
|
|||
r = extract_date(r)
|
||||
if not r:
|
||||
return None
|
||||
# python won't compare date and datetime, but many engines will upcast
|
||||
l, r = cast_as_datetime(l), cast_as_datetime(r)
|
||||
|
||||
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
|
||||
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
|
||||
|
@ -431,7 +439,7 @@ def propagate_constants(expression, root=True):
|
|||
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
|
||||
):
|
||||
constant_mapping = {}
|
||||
for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
|
||||
for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
|
||||
if isinstance(expr, exp.EQ):
|
||||
l, r = expr.left, expr.right
|
||||
|
||||
|
@ -544,7 +552,37 @@ def simplify_literals(expression, root=True):
|
|||
return expression
|
||||
|
||||
|
||||
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
|
||||
|
||||
|
||||
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
|
||||
this = _simplify_integer_cast(expr.this)
|
||||
else:
|
||||
this = expr.this
|
||||
|
||||
if isinstance(expr, exp.Cast) and this.is_int:
|
||||
num = int(this.name)
|
||||
|
||||
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
|
||||
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
|
||||
# engine-dependent
|
||||
if (
|
||||
TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
|
||||
) or (
|
||||
UTINYINT_MIN <= num <= UTINYINT_MAX
|
||||
and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
|
||||
):
|
||||
return this
|
||||
|
||||
return expr
|
||||
|
||||
|
||||
def _simplify_binary(expression, a, b):
|
||||
if isinstance(expression, COMPARISONS):
|
||||
a = _simplify_integer_cast(a)
|
||||
b = _simplify_integer_cast(b)
|
||||
|
||||
if isinstance(expression, exp.Is):
|
||||
if isinstance(b, exp.Not):
|
||||
c = b.this
|
||||
|
@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b):
|
|||
return exp.true() if not_ else exp.false()
|
||||
if is_null(a):
|
||||
return exp.false() if not_ else exp.true()
|
||||
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
|
||||
elif isinstance(expression, NULL_OK):
|
||||
return None
|
||||
elif is_null(a) or is_null(b):
|
||||
return exp.null()
|
||||
|
@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b):
|
|||
if boolean:
|
||||
return boolean
|
||||
elif _is_date_literal(a) and isinstance(b, exp.Interval):
|
||||
a, b = extract_date(a), extract_interval(b)
|
||||
if a and b:
|
||||
date, b = extract_date(a), extract_interval(b)
|
||||
if date and b:
|
||||
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
|
||||
return date_literal(a + b)
|
||||
return date_literal(date + b, extract_type(a))
|
||||
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
|
||||
return date_literal(a - b)
|
||||
return date_literal(date - b, extract_type(a))
|
||||
elif isinstance(a, exp.Interval) and _is_date_literal(b):
|
||||
a, b = extract_interval(a), extract_date(b)
|
||||
a, date = extract_interval(a), extract_date(b)
|
||||
# you cannot subtract a date from an interval
|
||||
if a and b and isinstance(expression, exp.Add):
|
||||
return date_literal(a + b)
|
||||
return date_literal(a + date, extract_type(b))
|
||||
elif _is_date_literal(a) and _is_date_literal(b):
|
||||
if isinstance(expression, exp.Predicate):
|
||||
a, b = extract_date(a), extract_date(b)
|
||||
|
@ -618,12 +656,16 @@ def simplify_parens(expression):
|
|||
|
||||
this = expression.this
|
||||
parent = expression.parent
|
||||
parent_is_predicate = isinstance(parent, exp.Predicate)
|
||||
|
||||
if not isinstance(this, exp.Select) and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
or not isinstance(this, exp.Binary)
|
||||
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
|
||||
or (
|
||||
not isinstance(this, exp.Binary)
|
||||
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
||||
)
|
||||
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
|
@ -632,24 +674,12 @@ def simplify_parens(expression):
|
|||
return expression
|
||||
|
||||
|
||||
NONNULL_CONSTANTS = (
|
||||
exp.Literal,
|
||||
exp.Boolean,
|
||||
)
|
||||
|
||||
CONSTANTS = (
|
||||
exp.Literal,
|
||||
exp.Boolean,
|
||||
exp.Null,
|
||||
)
|
||||
|
||||
|
||||
def _is_nonnull_constant(expression: exp.Expression) -> bool:
|
||||
return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
|
||||
return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
|
||||
|
||||
|
||||
def _is_constant(expression: exp.Expression) -> bool:
|
||||
return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
|
||||
return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
|
||||
|
||||
|
||||
def simplify_coalesce(expression):
|
||||
|
@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti
|
|||
return floor, floor + interval(unit)
|
||||
|
||||
|
||||
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
|
||||
def _datetrunc_eq_expression(
|
||||
left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
|
||||
) -> exp.Expression:
|
||||
"""Get the logical expression for a date range"""
|
||||
return exp.and_(
|
||||
left >= date_literal(drange[0]),
|
||||
left < date_literal(drange[1]),
|
||||
left >= date_literal(drange[0], target_type),
|
||||
left < date_literal(drange[1], target_type),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
|
||||
def _datetrunc_eq(
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
left: exp.Expression,
|
||||
date: datetime.date,
|
||||
unit: str,
|
||||
dialect: Dialect,
|
||||
target_type: t.Optional[exp.DataType],
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
return _datetrunc_eq_expression(left, drange)
|
||||
return _datetrunc_eq_expression(left, drange, target_type)
|
||||
|
||||
|
||||
def _datetrunc_neq(
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
left: exp.Expression,
|
||||
date: datetime.date,
|
||||
unit: str,
|
||||
dialect: Dialect,
|
||||
target_type: t.Optional[exp.DataType],
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
return exp.and_(
|
||||
left < date_literal(drange[0]),
|
||||
left >= date_literal(drange[1]),
|
||||
left < date_literal(drange[0], target_type),
|
||||
left >= date_literal(drange[1], target_type),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
|
||||
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
|
||||
exp.LT: lambda l, dt, u, d: l
|
||||
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
|
||||
exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
|
||||
exp.LT: lambda l, dt, u, d, t: l
|
||||
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
|
||||
exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
|
||||
exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
|
||||
exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
|
||||
exp.EQ: _datetrunc_eq,
|
||||
exp.NEQ: _datetrunc_neq,
|
||||
}
|
||||
|
@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
|
|||
comparison = expression.__class__
|
||||
|
||||
if isinstance(expression, DATETRUNCS):
|
||||
date = extract_date(expression.this)
|
||||
this = expression.this
|
||||
trunc_type = extract_type(this)
|
||||
date = extract_date(this)
|
||||
if date and expression.unit:
|
||||
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
|
||||
return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
|
||||
elif comparison not in DATETRUNC_COMPARISONS:
|
||||
return expression
|
||||
|
||||
|
@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
|
|||
return expression
|
||||
|
||||
l = t.cast(exp.DateTrunc, l)
|
||||
trunc_arg = l.this
|
||||
unit = l.unit.name.lower()
|
||||
date = extract_date(r)
|
||||
|
||||
if not date:
|
||||
return expression
|
||||
|
||||
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
|
||||
elif isinstance(expression, exp.In):
|
||||
return (
|
||||
DATETRUNC_BINARY_COMPARISONS[comparison](
|
||||
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
|
||||
)
|
||||
or expression
|
||||
)
|
||||
|
||||
if isinstance(expression, exp.In):
|
||||
l = expression.this
|
||||
rs = expression.expressions
|
||||
|
||||
|
@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
|
|||
return expression
|
||||
|
||||
ranges = merge_ranges(ranges)
|
||||
target_type = extract_type(l, *rs)
|
||||
|
||||
return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
|
||||
return exp.or_(
|
||||
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -954,7 +1006,7 @@ JOINS = {
|
|||
def remove_where_true(expression):
|
||||
for where in expression.find_all(exp.Where):
|
||||
if always_true(where.this):
|
||||
where.parent.set("where", None)
|
||||
where.pop()
|
||||
for join in expression.find_all(exp.Join):
|
||||
if (
|
||||
always_true(join.args.get("on"))
|
||||
|
@ -962,7 +1014,7 @@ def remove_where_true(expression):
|
|||
and not join.args.get("method")
|
||||
and (join.side, join.kind) in JOINS
|
||||
):
|
||||
join.set("on", None)
|
||||
join.args["on"].pop()
|
||||
join.set("side", None)
|
||||
join.set("kind", "CROSS")
|
||||
|
||||
|
@ -1067,15 +1119,25 @@ def extract_interval(expression):
|
|||
return None
|
||||
|
||||
|
||||
def date_literal(date):
|
||||
return exp.cast(
|
||||
exp.Literal.string(date),
|
||||
(
|
||||
def extract_type(*expressions):
|
||||
target_type = None
|
||||
for expression in expressions:
|
||||
target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
|
||||
if target_type:
|
||||
break
|
||||
|
||||
return target_type
|
||||
|
||||
|
||||
def date_literal(date, target_type=None):
|
||||
if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
|
||||
target_type = (
|
||||
exp.DataType.Type.DATETIME
|
||||
if isinstance(date, datetime.datetime)
|
||||
else exp.DataType.Type.DATE
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return exp.cast(exp.Literal.string(date), target_type)
|
||||
|
||||
|
||||
def interval(unit: str, n: int = 1):
|
||||
|
@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str:
|
|||
Sorting and deduping sql is a necessary step for optimization. Calling the actual
|
||||
generator is expensive so we have a bare minimum sql generator here.
|
||||
"""
|
||||
if expression is None:
|
||||
return "_"
|
||||
if is_iterable(expression):
|
||||
return ",".join(gen(e) for e in expression)
|
||||
if not isinstance(expression, exp.Expression):
|
||||
return str(expression)
|
||||
|
||||
etype = type(expression)
|
||||
if etype in GEN_MAP:
|
||||
return GEN_MAP[etype](expression)
|
||||
return f"{expression.key} {gen(expression.args.values())}"
|
||||
return Gen().gen(expression)
|
||||
|
||||
|
||||
GEN_MAP = {
|
||||
exp.Add: lambda e: _binary(e, "+"),
|
||||
exp.And: lambda e: _binary(e, "AND"),
|
||||
exp.Anonymous: lambda e: _anonymous(e),
|
||||
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
|
||||
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
|
||||
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
|
||||
exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
|
||||
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
|
||||
exp.Div: lambda e: _binary(e, "/"),
|
||||
exp.Dot: lambda e: _binary(e, "."),
|
||||
exp.EQ: lambda e: _binary(e, "="),
|
||||
exp.GT: lambda e: _binary(e, ">"),
|
||||
exp.GTE: lambda e: _binary(e, ">="),
|
||||
exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
|
||||
exp.ILike: lambda e: _binary(e, "ILIKE"),
|
||||
exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
|
||||
exp.Is: lambda e: _binary(e, "IS"),
|
||||
exp.Like: lambda e: _binary(e, "LIKE"),
|
||||
exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
|
||||
exp.LT: lambda e: _binary(e, "<"),
|
||||
exp.LTE: lambda e: _binary(e, "<="),
|
||||
exp.Mod: lambda e: _binary(e, "%"),
|
||||
exp.Mul: lambda e: _binary(e, "*"),
|
||||
exp.Neg: lambda e: _unary(e, "-"),
|
||||
exp.NEQ: lambda e: _binary(e, "<>"),
|
||||
exp.Not: lambda e: _unary(e, "NOT"),
|
||||
exp.Null: lambda e: "NULL",
|
||||
exp.Or: lambda e: _binary(e, "OR"),
|
||||
exp.Paren: lambda e: f"({gen(e.this)})",
|
||||
exp.Sub: lambda e: _binary(e, "-"),
|
||||
exp.Subquery: lambda e: f"({gen(e.args.values())})",
|
||||
exp.Table: lambda e: gen(e.args.values()),
|
||||
exp.Var: lambda e: e.name,
|
||||
}
|
||||
class Gen:
|
||||
def __init__(self):
|
||||
self.stack = []
|
||||
self.sqls = []
|
||||
|
||||
def gen(self, expression: exp.Expression) -> str:
|
||||
self.stack = [expression]
|
||||
self.sqls.clear()
|
||||
|
||||
def _anonymous(e: exp.Anonymous) -> str:
|
||||
this = e.this
|
||||
if isinstance(this, str):
|
||||
name = this.upper()
|
||||
elif isinstance(this, exp.Identifier):
|
||||
name = f'"{this.name}"' if this.quoted else this.name.upper()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
|
||||
while self.stack:
|
||||
node = self.stack.pop()
|
||||
|
||||
if isinstance(node, exp.Expression):
|
||||
exp_handler_name = f"{node.key}_sql"
|
||||
|
||||
if hasattr(self, exp_handler_name):
|
||||
getattr(self, exp_handler_name)(node)
|
||||
elif isinstance(node, exp.Func):
|
||||
self._function(node)
|
||||
else:
|
||||
key = node.key.upper()
|
||||
self.stack.append(f"{key} " if self._args(node) else key)
|
||||
elif type(node) is list:
|
||||
for n in reversed(node):
|
||||
if n is not None:
|
||||
self.stack.extend((n, ","))
|
||||
if node:
|
||||
self.stack.pop()
|
||||
else:
|
||||
if node is not None:
|
||||
self.sqls.append(str(node))
|
||||
|
||||
return "".join(self.sqls)
|
||||
|
||||
def add_sql(self, e: exp.Add) -> None:
|
||||
self._binary(e, " + ")
|
||||
|
||||
def alias_sql(self, e: exp.Alias) -> None:
|
||||
self.stack.extend(
|
||||
(
|
||||
e.args.get("alias"),
|
||||
" AS ",
|
||||
e.args.get("this"),
|
||||
)
|
||||
)
|
||||
|
||||
return f"{name} {','.join(gen(e) for e in e.expressions)}"
|
||||
def and_sql(self, e: exp.And) -> None:
|
||||
self._binary(e, " AND ")
|
||||
|
||||
def anonymous_sql(self, e: exp.Anonymous) -> None:
|
||||
this = e.this
|
||||
if isinstance(this, str):
|
||||
name = this.upper()
|
||||
elif isinstance(this, exp.Identifier):
|
||||
name = this.this
|
||||
name = f'"{name}"' if this.quoted else name.upper()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
|
||||
)
|
||||
|
||||
def _binary(e: exp.Binary, op: str) -> str:
|
||||
return f"{gen(e.left)} {op} {gen(e.right)}"
|
||||
self.stack.extend(
|
||||
(
|
||||
")",
|
||||
e.expressions,
|
||||
"(",
|
||||
name,
|
||||
)
|
||||
)
|
||||
|
||||
def between_sql(self, e: exp.Between) -> None:
|
||||
self.stack.extend(
|
||||
(
|
||||
e.args.get("high"),
|
||||
" AND ",
|
||||
e.args.get("low"),
|
||||
" BETWEEN ",
|
||||
e.this,
|
||||
)
|
||||
)
|
||||
|
||||
def _unary(e: exp.Unary, op: str) -> str:
|
||||
return f"{op} {gen(e.this)}"
|
||||
def boolean_sql(self, e: exp.Boolean) -> None:
|
||||
self.stack.append("TRUE" if e.this else "FALSE")
|
||||
|
||||
def bracket_sql(self, e: exp.Bracket) -> None:
|
||||
self.stack.extend(
|
||||
(
|
||||
"]",
|
||||
e.expressions,
|
||||
"[",
|
||||
e.this,
|
||||
)
|
||||
)
|
||||
|
||||
def column_sql(self, e: exp.Column) -> None:
|
||||
for p in reversed(e.parts):
|
||||
self.stack.extend((p, "."))
|
||||
self.stack.pop()
|
||||
|
||||
def datatype_sql(self, e: exp.DataType) -> None:
|
||||
self._args(e, 1)
|
||||
self.stack.append(f"{e.this.name} ")
|
||||
|
||||
def div_sql(self, e: exp.Div) -> None:
|
||||
self._binary(e, " / ")
|
||||
|
||||
def dot_sql(self, e: exp.Dot) -> None:
|
||||
self._binary(e, ".")
|
||||
|
||||
def eq_sql(self, e: exp.EQ) -> None:
|
||||
self._binary(e, " = ")
|
||||
|
||||
def from_sql(self, e: exp.From) -> None:
|
||||
self.stack.extend((e.this, "FROM "))
|
||||
|
||||
def gt_sql(self, e: exp.GT) -> None:
|
||||
self._binary(e, " > ")
|
||||
|
||||
def gte_sql(self, e: exp.GTE) -> None:
|
||||
self._binary(e, " >= ")
|
||||
|
||||
def identifier_sql(self, e: exp.Identifier) -> None:
|
||||
self.stack.append(f'"{e.this}"' if e.quoted else e.this)
|
||||
|
||||
def ilike_sql(self, e: exp.ILike) -> None:
|
||||
self._binary(e, " ILIKE ")
|
||||
|
||||
def in_sql(self, e: exp.In) -> None:
|
||||
self.stack.append(")")
|
||||
self._args(e, 1)
|
||||
self.stack.extend(
|
||||
(
|
||||
"(",
|
||||
" IN ",
|
||||
e.this,
|
||||
)
|
||||
)
|
||||
|
||||
def intdiv_sql(self, e: exp.IntDiv) -> None:
|
||||
self._binary(e, " DIV ")
|
||||
|
||||
def is_sql(self, e: exp.Is) -> None:
|
||||
self._binary(e, " IS ")
|
||||
|
||||
def like_sql(self, e: exp.Like) -> None:
|
||||
self._binary(e, " Like ")
|
||||
|
||||
def literal_sql(self, e: exp.Literal) -> None:
|
||||
self.stack.append(f"'{e.this}'" if e.is_string else e.this)
|
||||
|
||||
def lt_sql(self, e: exp.LT) -> None:
|
||||
self._binary(e, " < ")
|
||||
|
||||
def lte_sql(self, e: exp.LTE) -> None:
|
||||
self._binary(e, " <= ")
|
||||
|
||||
def mod_sql(self, e: exp.Mod) -> None:
|
||||
self._binary(e, " % ")
|
||||
|
||||
def mul_sql(self, e: exp.Mul) -> None:
|
||||
self._binary(e, " * ")
|
||||
|
||||
def neg_sql(self, e: exp.Neg) -> None:
|
||||
self._unary(e, "-")
|
||||
|
||||
def neq_sql(self, e: exp.NEQ) -> None:
|
||||
self._binary(e, " <> ")
|
||||
|
||||
def not_sql(self, e: exp.Not) -> None:
|
||||
self._unary(e, "NOT ")
|
||||
|
||||
def null_sql(self, e: exp.Null) -> None:
|
||||
self.stack.append("NULL")
|
||||
|
||||
def or_sql(self, e: exp.Or) -> None:
|
||||
self._binary(e, " OR ")
|
||||
|
||||
def paren_sql(self, e: exp.Paren) -> None:
|
||||
self.stack.extend(
|
||||
(
|
||||
")",
|
||||
e.this,
|
||||
"(",
|
||||
)
|
||||
)
|
||||
|
||||
def sub_sql(self, e: exp.Sub) -> None:
|
||||
self._binary(e, " - ")
|
||||
|
||||
def subquery_sql(self, e: exp.Subquery) -> None:
|
||||
self._args(e, 2)
|
||||
alias = e.args.get("alias")
|
||||
if alias:
|
||||
self.stack.append(alias)
|
||||
self.stack.extend((")", e.this, "("))
|
||||
|
||||
def table_sql(self, e: exp.Table) -> None:
|
||||
self._args(e, 4)
|
||||
alias = e.args.get("alias")
|
||||
if alias:
|
||||
self.stack.append(alias)
|
||||
for p in reversed(e.parts):
|
||||
self.stack.extend((p, "."))
|
||||
self.stack.pop()
|
||||
|
||||
def tablealias_sql(self, e: exp.TableAlias) -> None:
|
||||
columns = e.columns
|
||||
|
||||
if columns:
|
||||
self.stack.extend((")", columns, "("))
|
||||
|
||||
self.stack.extend((e.this, " AS "))
|
||||
|
||||
def var_sql(self, e: exp.Var) -> None:
|
||||
self.stack.append(e.this)
|
||||
|
||||
def _binary(self, e: exp.Binary, op: str) -> None:
|
||||
self.stack.extend((e.expression, op, e.this))
|
||||
|
||||
def _unary(self, e: exp.Unary, op: str) -> None:
|
||||
self.stack.extend((e.this, op))
|
||||
|
||||
def _function(self, e: exp.Func) -> None:
|
||||
self.stack.extend(
|
||||
(
|
||||
")",
|
||||
list(e.args.values()),
|
||||
"(",
|
||||
e.sql_name(),
|
||||
)
|
||||
)
|
||||
|
||||
def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
|
||||
kvs = []
|
||||
arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
|
||||
|
||||
for k in arg_types or arg_types:
|
||||
v = node.args.get(k)
|
||||
|
||||
if v is not None:
|
||||
kvs.append([f":{k}", v])
|
||||
if kvs:
|
||||
self.stack.append(kvs)
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
|
|||
if isinstance(predicate, exp.Binary):
|
||||
key = (
|
||||
predicate.right
|
||||
if any(node is column for node, *_ in predicate.left.walk())
|
||||
if any(node is column for node in predicate.left.walk())
|
||||
else predicate.left
|
||||
)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue