Merging upstream version 11.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
fdac67ef7f
commit
ba0f3f0bfa
112 changed files with 126100 additions and 230 deletions
|
@ -255,12 +255,23 @@ class TypeAnnotator:
|
|||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.Values):
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
source.expression.expressions[0].expressions,
|
||||
values,
|
||||
)
|
||||
}
|
||||
else:
|
||||
|
@ -272,7 +283,7 @@ class TypeAnnotator:
|
|||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
elif source:
|
||||
elif source and col.table in selects:
|
||||
col.type = selects[col.table][col.name].type
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
|
34
sqlglot/optimizer/expand_laterals.py
Normal file
34
sqlglot/optimizer/expand_laterals.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp
|
||||
|
||||
|
||||
def expand_laterals(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Expand lateral column alias references.
|
||||
|
||||
This assumes `qualify_columns` as already run.
|
||||
|
||||
Example:
|
||||
>>> import sqlglot
|
||||
>>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
|
||||
>>> expression = sqlglot.parse_one(sql)
|
||||
>>> expand_laterals(expression).sql()
|
||||
'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
|
||||
|
||||
Args:
|
||||
expression: expression to optimize
|
||||
Returns:
|
||||
optimized expression
|
||||
"""
|
||||
for select in expression.find_all(exp.Select):
|
||||
alias_to_expression: t.Dict[str, exp.Expression] = {}
|
||||
for projection in select.expressions:
|
||||
for column in projection.find_all(exp.Column):
|
||||
if not column.table and column.name in alias_to_expression:
|
||||
column.replace(alias_to_expression[column.name].copy())
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
return expression
|
|
@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
|
|||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_laterals import expand_laterals
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.lower_identities import lower_identities
|
||||
|
@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
|
|||
from sqlglot.optimizer.optimize_joins import optimize_joins
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
|
||||
from sqlglot.optimizer.qualify_tables import qualify_tables
|
||||
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
@ -22,6 +23,8 @@ RULES = (
|
|||
qualify_tables,
|
||||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
expand_laterals,
|
||||
validate_qualify_columns,
|
||||
pushdown_projections,
|
||||
normalize,
|
||||
unnest_subqueries,
|
||||
|
|
|
@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
|
|||
SELECT_ALL = object()
|
||||
|
||||
# Selection to use if selection list is empty
|
||||
DEFAULT_SELECTION = alias("1", "_")
|
||||
DEFAULT_SELECTION = lambda: alias("1", "_")
|
||||
|
||||
|
||||
def pushdown_projections(expression):
|
||||
|
@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
|
|||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION.copy())
|
||||
new_selections.append(DEFAULT_SELECTION())
|
||||
|
||||
scope.expression.set("expressions", new_selections)
|
||||
if removed:
|
||||
|
@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
|
|||
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
|
||||
]
|
||||
if not new_selections:
|
||||
new_selections.append(DEFAULT_SELECTION.copy())
|
||||
new_selections.append(DEFAULT_SELECTION())
|
||||
scope.expression.set("expressions", new_selections)
|
||||
|
|
|
@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
|
|||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver)
|
||||
_qualify_outputs(scope)
|
||||
_check_unknown_tables(scope)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def validate_qualify_columns(expression):
|
||||
"""Raise an `OptimizeError` if any columns aren't qualified"""
|
||||
unqualified_columns = []
|
||||
for scope in traverse_scope(expression):
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
unqualified_columns.extend(scope.unqualified_columns)
|
||||
if scope.external_columns and not scope.is_correlated_subquery:
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
|
||||
|
||||
if unqualified_columns:
|
||||
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
|
||||
return expression
|
||||
|
||||
|
||||
def _pop_table_column_aliases(derived_tables):
|
||||
"""
|
||||
Remove table column aliases.
|
||||
|
@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
|
|||
if not column_table:
|
||||
column_table = resolver.get_table(column_name)
|
||||
|
||||
if not scope.is_subquery and not scope.is_udtf:
|
||||
if column_table is None:
|
||||
raise OptimizeError(f"Ambiguous column: {column_name}")
|
||||
|
||||
# column_table can be a '' because bigquery unnest has no table alias
|
||||
if column_table:
|
||||
column.set("table", exp.to_identifier(column_table))
|
||||
|
@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver):
|
|||
for column in columns_missing_from_scope:
|
||||
column_table = resolver.get_table(column.name)
|
||||
|
||||
if column_table is None:
|
||||
raise OptimizeError(f"Ambiguous column: {column.name}")
|
||||
|
||||
column.set("table", exp.to_identifier(column_table))
|
||||
if column_table:
|
||||
column.set("table", exp.to_identifier(column_table))
|
||||
|
||||
|
||||
def _expand_stars(scope, resolver):
|
||||
|
@ -322,11 +329,6 @@ def _qualify_outputs(scope):
|
|||
scope.expression.set("expressions", new_selections)
|
||||
|
||||
|
||||
def _check_unknown_tables(scope):
|
||||
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
|
||||
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
|
||||
|
||||
|
||||
class _Resolver:
|
||||
"""
|
||||
Helper for resolving columns.
|
||||
|
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import csv_reader
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
|
||||
|
||||
def qualify_tables(expression, db=None, catalog=None, schema=None):
|
||||
|
@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
"""
|
||||
sequence = itertools.count()
|
||||
|
||||
next_name = lambda: f"_q_{next(sequence)}"
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in scope.ctes + scope.derived_tables:
|
||||
if not derived_table.args.get("alias"):
|
||||
|
@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
source = source.replace(
|
||||
alias(
|
||||
source.copy(),
|
||||
source.this if identifier else f"_q_{next(sequence)}",
|
||||
source.this if identifier else next_name(),
|
||||
table=True,
|
||||
)
|
||||
)
|
||||
|
@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
schema.add_table(
|
||||
source, {k: type(v).__name__ for k, v in zip(header, columns)}
|
||||
)
|
||||
elif isinstance(source, Scope) and source.is_udtf:
|
||||
udtf = source.expression
|
||||
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
|
||||
udtf.set("alias", table_alias)
|
||||
|
||||
if not table_alias.name:
|
||||
table_alias.set("this", next_name())
|
||||
|
||||
return expression
|
||||
|
|
|
@ -237,6 +237,8 @@ class Scope:
|
|||
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
|
||||
if (
|
||||
not ancestor
|
||||
# Window functions can have an ORDER BY clause
|
||||
or not isinstance(ancestor.parent, exp.Select)
|
||||
or column.table
|
||||
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
|
||||
):
|
||||
|
@ -479,7 +481,7 @@ def _traverse_scope(scope):
|
|||
elif isinstance(scope.expression, exp.Union):
|
||||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
_set_udtf_scope(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
else:
|
||||
|
@ -509,6 +511,22 @@ def _traverse_union(scope):
|
|||
scope.union_scopes = [left, right]
|
||||
|
||||
|
||||
def _set_udtf_scope(scope):
|
||||
parent = scope.expression.parent
|
||||
from_ = parent.args.get("from")
|
||||
|
||||
if not from_:
|
||||
return
|
||||
|
||||
for table in from_.expressions:
|
||||
if isinstance(table, exp.Table):
|
||||
scope.tables.append(table)
|
||||
elif isinstance(table, exp.Subquery):
|
||||
scope.subqueries.append(table)
|
||||
_add_table_sources(scope)
|
||||
_traverse_subqueries(scope)
|
||||
|
||||
|
||||
def _traverse_derived_tables(derived_tables, scope, scope_type):
|
||||
sources = {}
|
||||
is_cte = scope_type == ScopeType.CTE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue