
Merging upstream version 17.9.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:48:36 +01:00
parent 2bf6699c56
commit 9777880e00
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
87 changed files with 45907 additions and 42511 deletions

View file

@@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
 def _eliminate_derived_table(scope, existing_ctes, taken):
-    # This ensures we don't drop the "pivot" arg from a pivoted subquery
-    if scope.parent.pivots:
+    # This makes sure that we don't:
+    # - drop the "pivot" arg from a pivoted subquery
+    # - eliminate a lateral correlated subquery
+    if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
         return None

     parent = scope.expression.parent
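For context, a minimal sketch of what the new guard protects (the module path and the example query are assumptions, not part of the diff): a correlated LATERAL subquery references the outer table, so hoisting it into a CTE would break the correlation, and eliminate_subqueries should now leave it in place.

from sqlglot import parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# Hypothetical query: the lateral subquery is correlated with the outer table `x`.
sql = "SELECT * FROM x CROSS JOIN LATERAL (SELECT y.a FROM y WHERE y.b = x.b) AS l"

# With the guard above, the lateral subquery is not rewritten into a WITH clause.
print(eliminate_subqueries(parse_one(sql)).sql())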

View file

@@ -1,8 +1,23 @@
+from __future__ import annotations
+
+import typing as t
+
 from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType


+@t.overload
 def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+    ...
+
+
+@t.overload
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+    ...
+
+
+def normalize_identifiers(expression, dialect=None):
     """
     Normalize all unquoted identifiers to either lower or upper case, depending
     on the dialect. This essentially makes those identifiers case-insensitive.
@@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
         >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
         >>> normalize_identifiers(expression).sql()
         'SELECT bar.a AS a FROM "Foo".bar'
+        >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
+        'FOO'

     Args:
         expression: The expression to transform.
@@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
     Returns:
         The transformed expression.
     """
+    expression = exp.maybe_parse(expression, dialect=dialect)
     return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
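The new overloads let callers pass either a parsed expression or a raw SQL string; exp.maybe_parse converts the string before normalization. A small usage sketch mirroring the docstring examples above:

from sqlglot import parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# Parsed expressions behave as before: unquoted identifiers are normalized,
# quoted ones are preserved.
expression = parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
print(normalize_identifiers(expression).sql())  # SELECT bar.a AS a FROM "Foo".bar

# New in this change: plain strings are accepted and parsed first.
# Snowflake normalizes unquoted identifiers to upper case.
print(normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake"))  # FOO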

View file

@@ -39,6 +39,7 @@ def qualify_columns(
     """
     schema = ensure_schema(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
+    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

     for scope in traverse_scope(expression):
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
@@ -55,7 +56,7 @@ def qualify_columns(
         _expand_alias_refs(scope, resolver)

         if not isinstance(scope.expression, exp.UDTF):
-            _expand_stars(scope, resolver, using_column_tables)
+            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
             _qualify_outputs(scope)

         _expand_group_by(scope)
         _expand_order_by(scope, resolver)
@@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
 def _expand_stars(
-    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+    scope: Scope,
+    resolver: Resolver,
+    using_column_tables: t.Dict[str, t.Any],
+    pseudocolumns: t.Set[str],
 ) -> None:
     """Expand stars to lists of column selections"""
@@ -367,14 +371,8 @@ def _expand_stars(
             columns = resolver.get_source_columns(table, only_visible=True)

-            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
-            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
-            if resolver.schema.dialect == "bigquery":
-                columns = [
-                    name
-                    for name in columns
-                    if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
-                ]
+            if pseudocolumns:
+                columns = [name for name in columns if name.upper() not in pseudocolumns]

             if columns and "*" not in columns:
                 if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
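The hardcoded BigQuery check is replaced by a per-dialect PSEUDOCOLUMNS set consulted during star expansion. A hedged sketch of the intended effect (the schema contents, and the assumption that the BigQuery dialect now declares _PARTITIONTIME/_PARTITIONDATE in PSEUDOCOLUMNS, are illustrative):

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

# Hypothetical table that also exposes BigQuery's ingestion-time pseudo-column.
schema = {"t": {"x": "INT64", "_PARTITIONTIME": "TIMESTAMP"}}

expression = qualify(
    parse_one("SELECT * FROM t", read="bigquery"),
    schema=schema,
    dialect="bigquery",
)

# Assuming _PARTITIONTIME is listed in the dialect's PSEUDOCOLUMNS, the star
# expands to the regular columns only, matching BigQuery's SELECT * behavior.
print(expression.sql(dialect="bigquery"))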

View file

@@ -80,7 +80,9 @@ def qualify_tables(
                         header = next(reader)
                         columns = next(reader)
                         schema.add_table(
-                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
+                            source,
+                            {k: type(v).__name__ for k, v in zip(header, columns)},
+                            match_depth=False,
                         )
             elif isinstance(source, Scope) and source.is_udtf:
                 udtf = source.expression

View file

@@ -435,7 +435,10 @@ class Scope:
     @property
     def is_correlated_subquery(self):
         """Determine if this scope is a correlated subquery"""
-        return bool(self.is_subquery and self.external_columns)
+        return bool(
+            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+            and self.external_columns
+        )

     def rename_source(self, old_name, new_name):
         """Rename a source in this scope"""
@@ -486,7 +489,7 @@
 def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     """
-    Traverse an expression by it's "scopes".
+    Traverse an expression by its "scopes".

     "Scope" represents the current context of a Select statement.
@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     Returns:
         list[Scope]: scope instances
     """
-    if not isinstance(expression, exp.Unionable):
-        return []
-    return list(_traverse_scope(Scope(expression)))
+    if isinstance(expression, exp.Unionable) or (
+        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+    ):
+        return list(_traverse_scope(Scope(expression)))
+    return []


 def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
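With the new exp.DDL branch, scope traversal also covers DDL whose body is a query, e.g. CREATE TABLE ... AS SELECT ..., which previously produced no scopes. A short sketch (the printed details are illustrative, not asserted output):

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

# Previously this returned []; now the SELECT inside the CTAS gets its own scope.
for scope in traverse_scope(parse_one("CREATE TABLE new_tbl AS SELECT a, b FROM old_tbl")):
    print(type(scope.expression).__name__, list(scope.sources))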
@@ -539,7 +545,9 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.Table):
         yield from _traverse_tables(scope)
     elif isinstance(scope.expression, exp.UDTF):
-        pass
+        yield from _traverse_udtfs(scope)
+    elif isinstance(scope.expression, exp.DDL):
+        yield from _traverse_ddl(scope)
     else:
         logger.warning(
             "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
     for cte in scope.ctes:
         recursive_scope = None

-        # if the scope is a recursive cte, it must be in the form of
-        # base_case UNION recursive. thus the recursive scope is the first
-        # section of the union.
-        if scope.expression.args["with"].recursive:
+        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+        # thus the recursive scope is the first section of the union.
+        with_ = scope.expression.args.get("with")
+        if with_ and with_.recursive:
             union = cte.this

             if isinstance(union, exp.Union):
@@ -692,8 +700,7 @@ def _traverse_tables(scope):
             # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
             # Until then, this means that only a single, unaliased derived table is allowed (rather,
             # the latest one wins.
-            alias = expression.alias
-            sources[alias] = child_scope
+            sources[expression.alias] = child_scope

             # append the final child_scope yielded
             scopes.append(child_scope)
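As the comment notes, each derived table's child scope is registered in the parent scope's sources under its alias. A quick illustration (the query and alias are made up):

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

root = build_scope(parse_one("SELECT t.a FROM (SELECT a FROM x) AS t"))

# The derived table's scope is keyed by its alias "t" in the outer scope.
print(list(root.sources))                  # ['t']
print(root.sources["t"].expression.sql())  # the derived table's query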
@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
         scope.subquery_scopes.append(top)


+def _traverse_udtfs(scope):
+    if isinstance(scope.expression, exp.Unnest):
+        expressions = scope.expression.expressions
+    elif isinstance(scope.expression, exp.Lateral):
+        expressions = [scope.expression.this]
+    else:
+        expressions = []
+
+    sources = {}
+    for expression in expressions:
+        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+            top = None
+            for child_scope in _traverse_scope(
+                scope.branch(
+                    expression,
+                    scope_type=ScopeType.DERIVED_TABLE,
+                    outer_column_list=expression.alias_column_names,
+                )
+            ):
+                yield child_scope
+                top = child_scope
+                sources[expression.alias] = child_scope
+
+            scope.derived_table_scopes.append(top)
+            scope.table_scopes.append(top)
+
+    scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+    yield from _traverse_ctes(scope)
+
+    query_scope = scope.branch(
+        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+    )
+    query_scope._collect()
+    query_scope._ctes = scope.ctes + query_scope._ctes
+
+    yield from _traverse_scope(query_scope)
+
+
 def walk_in_scope(expression, bfs=True):
     """
     Returns a generator object which visits all nodes in the syntax tree, stopping at
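Together with the is_correlated_subquery change above, _traverse_udtfs gives the subquery inside a LATERAL its own scope, so a reference to the outer query should now mark it as correlated. A hedged sketch (the query shape and the printed output are illustrative):

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

sql = "SELECT * FROM x CROSS JOIN LATERAL (SELECT y.a FROM y WHERE y.b = x.b) AS l"

for scope in traverse_scope(parse_one(sql)):
    # The lateral subquery's scope references x.b, which is external to it, so it
    # should now report itself as a correlated subquery.
    if scope.is_correlated_subquery:
        print(type(scope.expression).__name__, [c.sql() for c in scope.external_columns])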

View file

@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
     if not predicate or parent_select is not predicate.parent_select:
         return

-    # this subquery returns a scalar and can just be converted to a cross join
+    # This subquery returns a scalar and can just be converted to a cross join
     if not isinstance(predicate, (exp.In, exp.Any)):
-        having = predicate.find_ancestor(exp.Having)
         column = exp.column(select.selects[0].alias_or_name, alias)
-        if having and having.parent_select is parent_select:
-            column = exp.Max(this=column)
-        _replace(select.parent, column)
-        parent_select.join(
-            select,
-            join_type="CROSS",
-            join_alias=alias,
-            copy=False,
-        )
+
+        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+        clause_parent_select = clause.parent_select if clause else None
+
+        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+            (not clause or clause_parent_select is not parent_select)
+            and (
+                parent_select.args.get("group")
+                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+            )
+        ):
+            column = exp.Max(this=column)
+
+        _replace(select.parent, column)
+        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
         return

     if select.find(exp.Limit, exp.Offset):
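Per the updated comment, a scalar subquery is turned into a CROSS JOIN, and the new clause check decides when the projected value must additionally be wrapped in MAX (a HAVING predicate in the same select, or an aggregating outer query). A hedged sketch of the basic rewrite (the alias the optimizer picks, e.g. _u_0, is not guaranteed here):

from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# An uncorrelated scalar subquery used in the outer WHERE clause.
sql = "SELECT * FROM x WHERE x.a > (SELECT SUM(y.a) AS a FROM y)"

# The subquery becomes a cross-joined derived table and the predicate is
# rewritten to reference its projected column.
print(unnest_subqueries(parse_one(sql)).sql())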