Merging upstream version 17.4.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
f4a8b128b0
commit
bf82c6c1c0
78 changed files with 35859 additions and 34717 deletions
|
@ -85,7 +85,7 @@ def _unique_outputs(scope):
|
|||
grouped_outputs = set()
|
||||
|
||||
unique_outputs = set()
|
||||
for select in scope.selects:
|
||||
for select in scope.expression.selects:
|
||||
output = select.unalias()
|
||||
if output in grouped_expressions:
|
||||
grouped_outputs.add(output)
|
||||
|
@ -105,7 +105,7 @@ def _unique_outputs(scope):
|
|||
|
||||
def _has_single_output_row(scope):
|
||||
return isinstance(scope.expression, exp.Select) and (
|
||||
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects)
|
||||
all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects)
|
||||
or _is_limit_1(scope)
|
||||
or not scope.expression.args.get("from")
|
||||
)
|
||||
|
|
|
@ -113,7 +113,7 @@ def _eliminate_union(scope, existing_ctes, taken):
|
|||
taken[alias] = scope
|
||||
|
||||
# Try to maintain the selections
|
||||
expressions = scope.selects
|
||||
expressions = scope.expression.selects
|
||||
selects = [
|
||||
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
|
||||
for e in expressions
|
||||
|
|
|
@ -12,7 +12,12 @@ def isolate_table_selects(expression, schema=None):
|
|||
continue
|
||||
|
||||
for _, source in scope.selected_sources.values():
|
||||
if not isinstance(source, exp.Table) or not schema.column_names(source):
|
||||
if (
|
||||
not isinstance(source, exp.Table)
|
||||
or not schema.column_names(source)
|
||||
or isinstance(source.parent, exp.Subquery)
|
||||
or isinstance(source.parent.parent, exp.Table)
|
||||
):
|
||||
continue
|
||||
|
||||
if not source.alias:
|
||||
|
|
|
@ -107,6 +107,7 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
|
|||
_merge_order(outer_scope, inner_scope)
|
||||
_merge_hints(outer_scope, inner_scope)
|
||||
outer_scope.clear_cache()
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -166,7 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
|
|||
if not inner_from:
|
||||
return False
|
||||
inner_from_table = inner_from.alias_or_name
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
|
||||
inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects}
|
||||
return any(
|
||||
col.table != inner_from_table
|
||||
for selection in selections
|
||||
|
|
|
@ -59,7 +59,7 @@ def reorder_joins(expression):
|
|||
dag = {name: other_table_names(join) for name, join in joins.items()}
|
||||
parent.set(
|
||||
"joins",
|
||||
[joins[name] for name in tsort(dag) if name != from_.alias_or_name],
|
||||
[joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins],
|
||||
)
|
||||
return expression
|
||||
|
||||
|
|
|
@ -42,7 +42,10 @@ def pushdown_predicates(expression):
|
|||
# so we limit the selected sources to only itself
|
||||
for join in select.args.get("joins") or []:
|
||||
name = join.alias_or_name
|
||||
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
|
||||
if name in scope.selected_sources:
|
||||
pushdown(
|
||||
join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
|
|
@ -48,12 +48,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
if any(select.is_star for select in right.selects):
|
||||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.selects):
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
referenced_columns[right] = [
|
||||
right.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.selects)
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
|
@ -90,7 +90,7 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
removed = False
|
||||
star = False
|
||||
|
||||
for selection in scope.selects:
|
||||
for selection in scope.expression.selects:
|
||||
name = selection.alias_or_name
|
||||
|
||||
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
|
||||
|
|
|
@ -192,13 +192,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if table and (not alias_expr or double_agg):
|
||||
column.set("table", table)
|
||||
elif not column.table and alias_expr and not double_agg:
|
||||
if isinstance(alias_expr, exp.Literal):
|
||||
if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
|
||||
if literal_index:
|
||||
column.replace(exp.Literal.number(i))
|
||||
else:
|
||||
column.replace(alias_expr.copy())
|
||||
|
||||
for i, projection in enumerate(scope.selects):
|
||||
for i, projection in enumerate(scope.expression.selects):
|
||||
replace_columns(projection)
|
||||
|
||||
if isinstance(projection, exp.Alias):
|
||||
|
@ -239,7 +239,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
|
|||
ordered.set("this", new_expression)
|
||||
|
||||
if scope.expression.args.get("group"):
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
|
||||
selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
|
||||
|
||||
for ordered in ordereds:
|
||||
ordered = ordered.this
|
||||
|
@ -270,7 +270,7 @@ def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t
|
|||
|
||||
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
|
||||
try:
|
||||
return scope.selects[int(node.this) - 1].assert_is(exp.Alias)
|
||||
return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
|
||||
except IndexError:
|
||||
raise OptimizeError(f"Unknown output column: {node.name}")
|
||||
|
||||
|
@ -347,7 +347,7 @@ def _expand_stars(
|
|||
if not pivot_output_columns:
|
||||
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
|
||||
|
||||
for expression in scope.selects:
|
||||
for expression in scope.expression.selects:
|
||||
if isinstance(expression, exp.Star):
|
||||
tables = list(scope.selected_sources)
|
||||
_add_except_columns(expression, tables, except_columns)
|
||||
|
@ -446,7 +446,7 @@ def _qualify_outputs(scope: Scope):
|
|||
new_selections = []
|
||||
|
||||
for i, (selection, aliased_column) in enumerate(
|
||||
itertools.zip_longest(scope.selects, scope.outer_column_list)
|
||||
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
|
||||
):
|
||||
if isinstance(selection, exp.Subquery):
|
||||
if not selection.output_name:
|
||||
|
|
|
@ -15,7 +15,8 @@ def qualify_tables(
|
|||
schema: t.Optional[Schema] = None,
|
||||
) -> E:
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified, unnested tables.
|
||||
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
|
||||
(t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
|
||||
|
||||
Examples:
|
||||
>>> import sqlglot
|
||||
|
@ -23,18 +24,9 @@ def qualify_tables(
|
|||
>>> qualify_tables(expression, db="db").sql()
|
||||
'SELECT 1 FROM db.tbl AS tbl'
|
||||
>>>
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl)")
|
||||
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
|
||||
>>> qualify_tables(expression).sql()
|
||||
'SELECT * FROM tbl AS tbl'
|
||||
>>>
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
|
||||
>>> qualify_tables(expression).sql()
|
||||
'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2'
|
||||
|
||||
Note:
|
||||
This rule effectively enforces a left-to-right join order, since all joins
|
||||
are unnested. This means that the optimizer doesn't necessarily preserve the
|
||||
original join order, e.g. when parentheses are used to specify it explicitly.
|
||||
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
|
||||
|
||||
Args:
|
||||
expression: Expression to qualify
|
||||
|
@ -49,6 +41,13 @@ def qualify_tables(
|
|||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
|
||||
if isinstance(derived_table, exp.Subquery):
|
||||
unnested = derived_table.unnest()
|
||||
if isinstance(unnested, exp.Table):
|
||||
joins = unnested.args.pop("joins", None)
|
||||
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
|
||||
derived_table.this.set("joins", joins)
|
||||
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = next_alias_name()
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
|
@ -66,19 +65,9 @@ def qualify_tables(
|
|||
if not source.args.get("catalog"):
|
||||
source.set("catalog", exp.to_identifier(catalog))
|
||||
|
||||
# Unnest joins attached in tables by appending them to the closest query
|
||||
for join in source.args.get("joins") or []:
|
||||
scope.expression.append("joins", join)
|
||||
|
||||
source.set("joins", None)
|
||||
source.set("wrapped", None)
|
||||
|
||||
if not source.alias:
|
||||
source = source.replace(
|
||||
alias(
|
||||
source, name or source.name or next_alias_name(), copy=True, table=True
|
||||
)
|
||||
)
|
||||
# Mutates the source by attaching an alias to it
|
||||
alias(source, name or source.name or next_alias_name(), copy=False, table=True)
|
||||
|
||||
pivots = source.args.get("pivots")
|
||||
if pivots and not pivots[0].alias:
|
||||
|
|
|
@ -122,7 +122,11 @@ class Scope:
|
|||
self._udtfs.append(node)
|
||||
elif isinstance(node, exp.CTE):
|
||||
self._ctes.append(node)
|
||||
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
|
||||
elif (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and _is_subquery_scope(node)
|
||||
):
|
||||
self._derived_tables.append(node)
|
||||
elif isinstance(node, exp.Subqueryable):
|
||||
self._subqueries.append(node)
|
||||
|
@ -274,6 +278,7 @@ class Scope:
|
|||
not ancestor
|
||||
or column.table
|
||||
or isinstance(ancestor, exp.Select)
|
||||
or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
|
||||
or (
|
||||
isinstance(ancestor, exp.Order)
|
||||
and (
|
||||
|
@ -340,23 +345,6 @@ class Scope:
|
|||
if isinstance(scope, Scope) and scope.is_cte
|
||||
}
|
||||
|
||||
@property
|
||||
def selects(self):
|
||||
"""
|
||||
Select expressions of this scope.
|
||||
|
||||
For example, for the following expression:
|
||||
SELECT 1 as a, 2 as b FROM x
|
||||
|
||||
The outputs are the "1 as a" and "2 as b" expressions.
|
||||
|
||||
Returns:
|
||||
list[exp.Expression]: expressions
|
||||
"""
|
||||
if isinstance(self.expression, exp.Union):
|
||||
return self.expression.unnest().selects
|
||||
return self.expression.selects
|
||||
|
||||
@property
|
||||
def external_columns(self):
|
||||
"""
|
||||
|
@ -548,6 +536,8 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.Table):
|
||||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
else:
|
||||
|
@ -620,6 +610,15 @@ def _traverse_ctes(scope):
|
|||
scope.sources.update(sources)
|
||||
|
||||
|
||||
def _is_subquery_scope(expression: exp.Subquery) -> bool:
|
||||
"""
|
||||
We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
|
||||
If an alias is present, it shadows all names under the Subquery, so that's an
|
||||
exception to this rule.
|
||||
"""
|
||||
return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)
|
||||
|
||||
|
||||
def _traverse_tables(scope):
|
||||
sources = {}
|
||||
|
||||
|
@ -629,9 +628,8 @@ def _traverse_tables(scope):
|
|||
if from_:
|
||||
expressions.append(from_.this)
|
||||
|
||||
for expression in (scope.expression, *scope.find_all(exp.Table)):
|
||||
for join in expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
for join in scope.expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
|
||||
if isinstance(scope.expression, exp.Table):
|
||||
expressions.append(scope.expression)
|
||||
|
@ -655,6 +653,8 @@ def _traverse_tables(scope):
|
|||
sources[find_new_name(sources, table_name)] = expression
|
||||
else:
|
||||
sources[source_name] = expression
|
||||
|
||||
expressions.extend(join.this for join in expression.args.get("joins") or [])
|
||||
continue
|
||||
|
||||
if not isinstance(expression, exp.DerivedTable):
|
||||
|
@ -664,10 +664,15 @@ def _traverse_tables(scope):
|
|||
lateral_sources = sources
|
||||
scope_type = ScopeType.UDTF
|
||||
scopes = scope.udtf_scopes
|
||||
else:
|
||||
elif _is_subquery_scope(expression):
|
||||
lateral_sources = None
|
||||
scope_type = ScopeType.DERIVED_TABLE
|
||||
scopes = scope.derived_table_scopes
|
||||
else:
|
||||
# Makes sure we check for possible sources in nested table constructs
|
||||
expressions.append(expression.this)
|
||||
expressions.extend(join.this for join in expression.args.get("joins") or [])
|
||||
continue
|
||||
|
||||
for child_scope in _traverse_scope(
|
||||
scope.branch(
|
||||
|
@ -728,7 +733,11 @@ def walk_in_scope(expression, bfs=True):
|
|||
continue
|
||||
if (
|
||||
isinstance(node, exp.CTE)
|
||||
or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
|
||||
or (
|
||||
isinstance(node, exp.Subquery)
|
||||
and isinstance(parent, (exp.From, exp.Join))
|
||||
and _is_subquery_scope(node)
|
||||
)
|
||||
or isinstance(node, exp.UDTF)
|
||||
or isinstance(node, exp.Subqueryable)
|
||||
):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue