1
0
Fork 0

Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:30:28 +01:00
parent ebba7c6a18
commit d26905e4af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
187 changed files with 86502 additions and 71397 deletions

View file

@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Exp,
exp.Ln,
exp.Log,
exp.Log2,
exp.Log10,
exp.Pow,
exp.Quantile,
exp.Round,
@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Dot: lambda self, e: self._annotate_dot(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
e, exp.DataType.build("ARRAY<DATE>")
),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.VarMap: lambda self, e: self._annotate_map(e),
}
NESTED_TYPES = {
@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
for value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004
@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this
left_type, right_type = left.type.this, right.type.this # type: ignore
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
self._set_type(expression, exp.DataType.Type.VARCHAR)
@ -484,25 +476,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)
@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)
# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
return expression.type
@t.no_type_check
def _annotate_by_args(
self,
@ -510,7 +487,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
*args: str,
promote: bool = False,
array: bool = False,
struct: bool = False,
) -> E:
self._annotate_args(expression)
@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
if struct:
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)
return expression
def _annotate_timeunit(
@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BIGINT)
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
self._set_type(
expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
)
return expression
@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression
def _annotate_struct_value(
    self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
    """Map a STRUCT member expression to the type entry it contributes.

    An aliased member (``value AS name``) and a key/value member
    (``key = value`` or ``key := value``) each become a ``ColumnDef``;
    any other member contributes its own (possibly ``None``) type.
    """
    alias_node = expression.args.get("alias")
    if alias_node:
        return exp.ColumnDef(this=alias_node.copy(), kind=expression.type)

    value = expression.expression
    if value:
        # key = value / key := value style struct entry
        return exp.ColumnDef(this=expression.this.copy(), kind=value.type)

    return expression.type
def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
    """Annotate a STRUCT expression with a nested STRUCT type built from its members."""
    self._annotate_args(expression)

    member_types = []
    for member in expression.expressions:
        member_types.append(self._annotate_struct_value(member))

    struct_type = exp.DataType(
        this=exp.DataType.Type.STRUCT,
        expressions=member_types,
        nested=True,
    )
    self._set_type(expression, struct_type)
    return expression
@t.overload
def _annotate_map(self, expression: exp.Map) -> exp.Map: ...

@t.overload
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...

def _annotate_map(self, expression):
    """Annotate MAP/VARMAP, narrowing to MAP<k, v> when both element types are known."""
    self._annotate_args(expression)

    annotated_type = exp.DataType(this=exp.DataType.Type.MAP)

    keys_arg = expression.args.get("keys")
    values_arg = expression.args.get("values")
    if isinstance(keys_arg, exp.Array) and isinstance(values_arg, exp.Array):
        unknown = exp.DataType.Type.UNKNOWN
        k_type = seq_get(keys_arg.type.expressions, 0) or unknown
        v_type = seq_get(values_arg.type.expressions, 0) or unknown

        # Only narrow the MAP type when both sides resolved to something concrete
        if k_type != unknown and v_type != unknown:
            annotated_type.set("expressions", [k_type, v_type])
            annotated_type.set("nested", True)

    self._set_type(expression, annotated_type)
    return expression
def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
self._annotate_args(expression)
map_type = exp.DataType(this=exp.DataType.Type.MAP)
arg = expression.this
if arg.is_type(exp.DataType.Type.STRUCT):
for coldef in arg.type.expressions:
kind = coldef.kind
if kind != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [exp.DataType.build("varchar"), kind])
map_type.set("nested", True)
break
self._set_type(expression, map_type)
return expression

View file

@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
def _canonicalize(expression: exp.Expression) -> exp.Expression:
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
return expression
return expression
return exp.replace_tree(expression, _canonicalize)
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
if (
isinstance(node, (exp.Date, exp.TsOrDsToDate))
and not node.expressions
and not node.args.get("zone")
):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
if not node.type:
@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and a.type.this in exp.DataType.TEMPORAL_TYPES
and b.type
and b.type.this
not in (
exp.DataType.Type.DATE,
exp.DataType.Type.INTERVAL,
)
and b.type.this in exp.DataType.TEXT_TYPES
):
_replace_cast(b, exp.DataType.Type.DATE)
_replace_cast(b, exp.DataType.Type.DATETIME)
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
@ -169,7 +170,7 @@ def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
for child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(expression.neq(0))

View file

@ -32,7 +32,7 @@ def eliminate_ctes(expression):
cte_node.pop()
# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
if with_node and len(with_node.expressions) <= 0:
with_node.pop()
# Decrement the ref count for all sources this CTE selects from

View file

@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
and not _is_recursive()
and not (inner_select.args.get("order") and outer_scope.is_union)
)

View file

@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue

View file

@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None):
if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)
def _normalize(node: E) -> E:
for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
if not node.meta.get("case_sensitive"):
exp.replace_children(node, _normalize)
node = dialect.normalize_identifier(node)
return node
dialect.normalize_identifier(node)
return _normalize(expression)
return expression

View file

@ -82,13 +82,13 @@ def optimize(
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
optimized = rule(optimized, **rule_kwargs)
return t.cast(exp.Expression, expression)
return optimized

View file

@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
if isinstance(node, exp.Join):
name = node.alias_or_name
predicate_tables = exp.column_table_names(predicate, name)
@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
node.where(inner_predicate, copy=False)
def pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
# pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
nodes = nodes_for_predicate(predicate, sources, scope_ref_count)
if table not in nodes:
continue

View file

@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
]
if scope.expression.args.get("by_name"):
referenced_columns[right] = referenced_columns[left]
else:
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections
or select.alias_or_name in parent_selections
]
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:

View file

@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not node:
return
for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
for column in walk_in_scope(node, prune=lambda node: node.is_star):
if not isinstance(column, exp.Column):
continue
@ -306,7 +306,7 @@ def _expand_positional_references(
else:
select = select.this
if isinstance(select, exp.Literal):
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
@ -425,7 +425,7 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
columns = columns or scope.outer_column_list
columns = columns or scope.outer_columns
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
):
if selection is None:
break
@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
"""Makes sure all identifiers that need to be quoted are quoted."""
return expression.transform(
Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
)
) # type: ignore
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:

View file

@ -56,7 +56,7 @@ def qualify_tables(
table.set("catalog", catalog)
if not isinstance(expression, exp.Query):
for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)):
for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)):
if isinstance(node, exp.Table):
_qualify(node)
@ -118,11 +118,11 @@ def qualify_tables(
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
else:
for node, parent, _ in scope.walk():
for node in scope.walk():
if (
isinstance(node, exp.Table)
and not node.alias
and isinstance(parent, (exp.From, exp.Join))
and isinstance(node.parent, (exp.From, exp.Join))
):
# Mutates the table by attaching an alias to it
alias(node, node.name, copy=False, table=True)

View file

@ -8,7 +8,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name
from sqlglot.helper import ensure_collection, find_new_name, seq_get
logger = logging.getLogger("sqlglot")
@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTES
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
@ -58,7 +58,7 @@ class Scope:
self,
expression,
sources=None,
outer_column_list=None,
outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
@ -70,7 +70,7 @@ class Scope:
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
@ -119,10 +119,11 @@ class Scope:
self._raw_columns = []
self._join_hints = []
for node, parent, _ in self.walk(bfs=False):
for node in self.walk(bfs=False):
if node is self.expression:
continue
elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
@ -132,10 +133,8 @@ class Scope:
self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
elif _is_derived_table(node) and isinstance(
node.parent, (exp.From, exp.Join, exp.Subquery)
):
self._derived_tables.append(node)
elif isinstance(node, exp.UNWRAPPED_QUERIES):
@ -438,11 +437,21 @@ class Scope:
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
stack = [self]
result = []
while stack:
scope = stack.pop()
result.append(scope)
stack.extend(
itertools.chain(
scope.cte_scopes,
scope.union_scopes,
scope.table_scopes,
scope.subquery_scopes,
)
)
yield from reversed(result)
def ref_count(self):
"""
@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
expression (exp.Expression): expression to traverse
expression: Expression to traverse
Returns:
list[Scope]: scope instances
A list of the created scope instances
"""
if isinstance(expression, exp.Query) or (
isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query)
):
if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
# We ignore the DDL expression and build a scope for its query instead
ddl_with = expression.args.get("with")
expression = expression.expression
# If the DDL has CTEs attached, we need to add them to the query, or
# prepend them if the query itself already has CTEs attached to it
if ddl_with:
ddl_with.pop()
query_ctes = expression.ctes
if not query_ctes:
expression.set("with", ddl_with)
else:
expression.args["with"].set("recursive", ddl_with.recursive)
expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
if isinstance(expression, exp.Query):
return list(_traverse_scope(Scope(expression)))
return []
@ -499,21 +522,21 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
expression: Expression to build the scope tree for.
Returns:
Scope: root scope
The root scope
"""
scopes = traverse_scope(expression)
if scopes:
return scopes[-1]
return None
return seq_get(traverse_scope(expression), -1)
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
@ -523,8 +546,6 @@ def _traverse_scope(scope):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
yield from _traverse_udtfs(scope)
elif isinstance(scope.expression, exp.DDL):
yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@ -541,30 +562,38 @@ def _traverse_select(scope):
def _traverse_union(scope):
yield from _traverse_ctes(scope)
prev_scope = None
union_scope_stack = [scope]
expression_stack = [scope.expression.right, scope.expression.left]
# The last scope to be yield should be the top most scope
left = None
for left in _traverse_scope(
scope.branch(
scope.expression.left,
outer_column_list=scope.outer_column_list,
while expression_stack:
expression = expression_stack.pop()
union_scope = union_scope_stack[-1]
new_scope = union_scope.branch(
expression,
outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
):
yield left
right = None
for right in _traverse_scope(
scope.branch(
scope.expression.right,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield right
if isinstance(expression, exp.Union):
yield from _traverse_ctes(new_scope)
scope.union_scopes = [left, right]
union_scope_stack.append(new_scope)
expression_stack.extend([expression.right, expression.left])
continue
for scope in _traverse_scope(new_scope):
yield scope
if prev_scope:
union_scope_stack.pop()
union_scope.union_scopes = [prev_scope, scope]
prev_scope = union_scope
yield union_scope
else:
prev_scope = scope
def _traverse_ctes(scope):
@ -588,7 +617,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
outer_column_list=cte.alias_column_names,
outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool:
as it doesn't introduce a new scope. If an alias is present, it shadows all names
under the Subquery, so that's one exception to this rule.
"""
return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES))
return isinstance(expression, exp.Subquery) and bool(
expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
)
def _traverse_tables(scope):
@ -681,7 +712,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
@ -719,13 +750,13 @@ def _traverse_udtfs(scope):
sources = {}
for expression in expressions:
if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
if _is_derived_table(expression):
top = None
for child_scope in _traverse_scope(
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
)
):
yield child_scope
@ -738,18 +769,6 @@ def _traverse_udtfs(scope):
scope.sources.update(sources)
def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes
yield from _traverse_scope(query_scope)
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntrax tree, stopping at
@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None):
# Whenever we set it to True, we exclude a subtree from traversal.
crossed_scope_boundary = False
for node, parent, key in expression.walk(
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
for node in expression.walk(
bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
):
crossed_scope_boundary = False
yield node, parent, key
yield node
if node is expression:
continue
if (
isinstance(node, exp.CTE)
or (
isinstance(node, exp.Subquery)
and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
and _is_derived_table(node)
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True):
Yields:
exp.Expression: nodes
"""
for expression, *_ in walk_in_scope(expression, bfs=bfs):
for expression in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression

View file

@ -9,19 +9,25 @@ from decimal import Decimal
import sqlglot
from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
# Final means that an expression should not be simplified
FINAL = "final"
# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(Exception):
pass
@ -63,14 +69,14 @@ def simplify(
group.meta[FINAL] = True
for e in expression.selects:
for node, *_ in e.walk():
for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break
having = expression.args.get("having")
if having:
for node, *_ in having.walk():
for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break
@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False):
r = extract_date(r)
if not r:
return None
# python won't compare date and datetime, but many engines will upcast
l, r = cast_as_datetime(l), cast_as_datetime(r)
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
@ -431,7 +439,7 @@ def propagate_constants(expression, root=True):
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
if isinstance(expr, exp.EQ):
l, r = expr.left, expr.right
@ -544,7 +552,37 @@ def simplify_literals(expression, root=True):
return expression
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """Strip side-effect-free (up)casts of byte-sized integer literals.

    Nested casts are simplified innermost-first. A cast of an integer literal
    is dropped only when the literal provably fits the target type family
    without narrowing: signed targets for values in [-128, 127], unsigned
    targets for values in [0, 255]. Anything else is left untouched, since
    a downcast may overflow and the behavior is engine-dependent.
    """
    inner = expr.this
    if isinstance(expr, exp.Cast) and isinstance(inner, exp.Cast):
        inner = _simplify_integer_cast(inner)

    if not (isinstance(expr, exp.Cast) and inner.is_int):
        return expr

    value = int(inner.name)
    target = expr.to.this

    fits_signed = (
        TINYINT_MIN <= value <= TINYINT_MAX and target in exp.DataType.SIGNED_INTEGER_TYPES
    )
    fits_unsigned = (
        UTINYINT_MIN <= value <= UTINYINT_MAX and target in exp.DataType.UNSIGNED_INTEGER_TYPES
    )
    return inner if fits_signed or fits_unsigned else expr
def _simplify_binary(expression, a, b):
if isinstance(expression, COMPARISONS):
a = _simplify_integer_cast(a)
b = _simplify_integer_cast(b)
if isinstance(expression, exp.Is):
if isinstance(b, exp.Not):
c = b.this
@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b):
return exp.true() if not_ else exp.false()
if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
elif isinstance(expression, NULL_OK):
return None
elif is_null(a) or is_null(b):
return exp.null()
@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
date, b = extract_date(a), extract_interval(b)
if date and b:
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
return date_literal(date + b, extract_type(a))
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
return date_literal(date - b, extract_type(a))
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
a, date = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
return date_literal(a + date, extract_type(b))
elif _is_date_literal(a) and _is_date_literal(b):
if isinstance(expression, exp.Predicate):
a, b = extract_date(a), extract_date(b)
@ -618,12 +656,16 @@ def simplify_parens(expression):
this = expression.this
parent = expression.parent
parent_is_predicate = isinstance(parent, exp.Predicate)
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
@ -632,24 +674,12 @@ def simplify_parens(expression):
return expression
NONNULL_CONSTANTS = (
exp.Literal,
exp.Boolean,
)
CONSTANTS = (
exp.Literal,
exp.Boolean,
exp.Null,
)
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True if `expression` is a constant that cannot be NULL.

    Non-NULL constants are the classes listed in ``exp.NONNULL_CONSTANTS``
    (e.g. plain literals and booleans) as well as casted date literals.
    """
    # NOTE: the diff left the pre-23.7 return (module-level NONNULL_CONSTANTS)
    # interleaved here; only the exp.NONNULL_CONSTANTS form is kept.
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
def _is_constant(expression: exp.Expression) -> bool:
    """Return True if `expression` is a constant (possibly NULL).

    Constants are the classes listed in ``exp.CONSTANTS`` (literals, booleans,
    NULL) as well as casted date literals.
    """
    # NOTE: the diff left the pre-23.7 return (module-level CONSTANTS)
    # interleaved here; only the exp.CONSTANTS form is kept.
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
def simplify_coalesce(expression):
@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti
return floor, floor + interval(unit)
def _datetrunc_eq_expression(
    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
    """Get the logical expression for a date range.

    Args:
        left: the un-truncated operand of the DATE_TRUNC call.
        drange: half-open interval ``(floor, floor + one unit)`` produced by
            `_datetrunc_range`.
        target_type: temporal type to cast the boundary literals to, if known.

    Returns:
        ``left >= drange[0] AND left < drange[1]``.
    """
    return exp.and_(
        left >= date_literal(drange[0], target_type),
        left < date_literal(drange[1], target_type),
        copy=False,
    )
def _datetrunc_eq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Simplify ``DATE_TRUNC(unit, left) = date`` into a plain range check.

    Returns None when `date` is not aligned to `unit` (no valid range), in
    which case the caller keeps the original expression.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange, target_type)
def _datetrunc_neq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Simplify ``DATE_TRUNC(unit, left) <> date`` into a plain range check.

    The truncated value equals `date` exactly when `left` lies in the
    half-open range ``[drange[0], drange[1])``, so the inequality holds on the
    complement: ``left < drange[0] OR left >= drange[1]``.

    NOTE(review): the original combined the two bounds with ``exp.and_``,
    which is unsatisfiable (nothing is both below the floor and at/above the
    ceiling); the complement of a range is a disjunction.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return exp.or_(
        left < date_literal(drange[0], target_type),
        left >= date_literal(drange[1], target_type),
        copy=False,
    )
# Rewrites ``DATE_TRUNC(u, l) <op> dt`` into a comparison on the raw operand
# ``l``. Each transform receives (l, dt, unit, dialect, target_type); the
# stale 4-argument lambdas from before 23.7 that the diff left interleaved
# here have been dropped.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
comparison = expression.__class__
if isinstance(expression, DATETRUNCS):
date = extract_date(expression.this)
this = expression.this
trunc_type = extract_type(this)
date = extract_date(this)
if date and expression.unit:
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
elif comparison not in DATETRUNC_COMPARISONS:
return expression
@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
l = t.cast(exp.DateTrunc, l)
trunc_arg = l.this
unit = l.unit.name.lower()
date = extract_date(r)
if not date:
return expression
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
return (
DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
)
or expression
)
if isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs)
return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
)
return expression
@ -954,7 +1006,7 @@ JOINS = {
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
where.parent.set("where", None)
where.pop()
for join in expression.find_all(exp.Join):
if (
always_true(join.args.get("on"))
@ -962,7 +1014,7 @@ def remove_where_true(expression):
and not join.args.get("method")
and (join.side, join.kind) in JOINS
):
join.set("on", None)
join.args["on"].pop()
join.set("side", None)
join.set("kind", "CROSS")
@ -1067,15 +1119,25 @@ def extract_interval(expression):
return None
def extract_type(*expressions):
    """Return the first available type among `expressions`.

    For CAST nodes the cast target (``to``) is used, otherwise the node's
    annotated ``type``. Returns None when no expression carries a type.
    """
    # NOTE: orphaned header lines of the pre-23.7 date_literal that the diff
    # left in front of this helper have been removed.
    target_type = None
    for expression in expressions:
        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
        if target_type:
            break

    return target_type
def date_literal(date, target_type=None):
    """Wrap a Python date/datetime in a ``CAST('<iso>' AS <temporal type>)``.

    Args:
        date: a ``datetime.date`` or ``datetime.datetime`` value.
        target_type: preferred temporal type for the cast; when absent or not
            a temporal type, DATETIME/DATE is inferred from the Python value.
    """
    # Stray closing parens from the pre-23.7 body (left behind by the diff)
    # have been removed.
    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        target_type = (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        )

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str:
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive so we have a bare minimum sql generator here.
"""
if expression is None:
return "_"
if is_iterable(expression):
return ",".join(gen(e) for e in expression)
if not isinstance(expression, exp.Expression):
return str(expression)
etype = type(expression)
if etype in GEN_MAP:
return GEN_MAP[etype](expression)
return f"{expression.key} {gen(expression.args.values())}"
return Gen().gen(expression)
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),
exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
exp.ILike: lambda e: _binary(e, "ILIKE"),
exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
exp.Is: lambda e: _binary(e, "IS"),
exp.Like: lambda e: _binary(e, "LIKE"),
exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
exp.LT: lambda e: _binary(e, "<"),
exp.LTE: lambda e: _binary(e, "<="),
exp.Mod: lambda e: _binary(e, "%"),
exp.Mul: lambda e: _binary(e, "*"),
exp.Neg: lambda e: _unary(e, "-"),
exp.NEQ: lambda e: _binary(e, "<>"),
exp.Not: lambda e: _unary(e, "NOT"),
exp.Null: lambda e: "NULL",
exp.Or: lambda e: _binary(e, "OR"),
exp.Paren: lambda e: f"({gen(e.this)})",
exp.Sub: lambda e: _binary(e, "-"),
exp.Subquery: lambda e: f"({gen(e.args.values())})",
exp.Table: lambda e: gen(e.args.values()),
exp.Var: lambda e: e.name,
}
class Gen:
    """Bare-minimum, dialect-agnostic SQL renderer.

    Sorting and deduping expressions during simplification needs a canonical
    string form, but invoking the real generator is expensive.  This walker
    produces a cheap deterministic rendering instead: it keeps an explicit
    work stack of nodes and string fragments (no recursion) and appends
    output pieces to ``self.sqls``.

    The deleted pre-23.7 module-level helpers (``_anonymous``, ``_binary``,
    ``_unary``) that the diff left spliced into the class body have been
    removed.
    """

    def __init__(self):
        # Pending nodes / literal SQL fragments, processed LIFO.
        self.stack = []
        # Accumulated output fragments, joined at the end of gen().
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and everything under it to a single string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: KEY followed by the node's args, if any.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Comma-separate the items; drop the trailing ",".
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    self.stack.pop()
            else:
                # Terminal fragment (string, number, ...).
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-join the column parts; drop the trailing ".".
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Was " Like "; normalized to uppercase for consistency with every
        # other operator handler (the rendering is only used internally as a
        # sort/dedup key, so only internal consistency matters).
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        # Dot-join the table parts; drop the trailing ".".
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns
        if columns:
            self.stack.extend((")", columns, "("))
        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Stack is LIFO, so push right operand first.
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push `node`'s args (from `arg_index` on) as ":key value" pairs.

        Returns True if anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        # Was ``for k in arg_types or arg_types`` — the ``or`` was redundant.
        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])

        if kvs:
            self.stack.append(kvs)
            return True
        return False

View file

@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
if isinstance(predicate, exp.Binary):
key = (
predicate.right
if any(node is column for node, *_ in predicate.left.walk())
if any(node is column for node in predicate.left.walk())
else predicate.left
)
else: