1
0
Fork 0

Merging upstream version 11.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:23:26 +01:00
parent fdac67ef7f
commit ba0f3f0bfa
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
112 changed files with 126100 additions and 230 deletions

View file

@ -255,12 +255,23 @@ class TypeAnnotator:
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.Values):
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
source.expression.expressions[0].expressions,
values,
)
}
else:
@ -272,7 +283,7 @@ class TypeAnnotator:
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source:
elif source and col.table in selects:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)

View file

@ -0,0 +1,34 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
def expand_laterals(expression: exp.Expression) -> exp.Expression:
"""
Expand lateral column alias references.
This assumes `qualify_columns` as already run.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
>>> expression = sqlglot.parse_one(sql)
>>> expand_laterals(expression).sql()
'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
Args:
expression: expression to optimize
Returns:
optimized expression
"""
for select in expression.find_all(exp.Select):
alias_to_expression: t.Dict[str, exp.Expression] = {}
for projection in select.expressions:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
return expression

View file

@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@ -22,6 +23,8 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
expand_laterals,
validate_qualify_columns,
pushdown_projections,
normalize,
unnest_subqueries,

View file

@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
SELECT_ALL = object()
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression):
@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
if removed:
@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)

View file

@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
return expression
def validate_qualify_columns(expression):
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
return expression
def _pop_table_column_aliases(derived_tables):
"""
Remove table column aliases.
@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf:
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", exp.to_identifier(column_table))
@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver):
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column.name}")
column.set("table", exp.to_identifier(column_table))
if column_table:
column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
@ -322,11 +329,6 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
def _check_unknown_tables(scope):
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
class _Resolver:
"""
Helper for resolving columns.

View file

@ -2,7 +2,7 @@ import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
sequence = itertools.count()
next_name = lambda: f"_q_{next(sequence)}"
for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
source.this if identifier else next_name(),
table=True,
)
)
@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
udtf.set("alias", table_alias)
if not table_alias.name:
table_alias.set("this", next_name())
return expression

View file

@ -237,6 +237,8 @@ class Scope:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
if (
not ancestor
# Window functions can have an ORDER BY clause
or not isinstance(ancestor.parent, exp.Select)
or column.table
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
@ -479,7 +481,7 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
_set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
else:
@ -509,6 +511,22 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
def _set_udtf_scope(scope):
parent = scope.expression.parent
from_ = parent.args.get("from")
if not from_:
return
for table in from_.expressions:
if isinstance(table, exp.Table):
scope.tables.append(table)
elif isinstance(table, exp.Subquery):
scope.subqueries.append(table)
_add_table_sources(scope)
_traverse_subqueries(scope)
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE