Merging upstream version 17.9.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 2bf6699c56
commit 9777880e00
87 changed files with 45907 additions and 42511 deletions
@@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
 def _eliminate_derived_table(scope, existing_ctes, taken):
-    # This ensures we don't drop the "pivot" arg from a pivoted subquery
-    if scope.parent.pivots:
+    # This makes sure that we don't:
+    # - drop the "pivot" arg from a pivoted subquery
+    # - eliminate a lateral correlated subquery
+    if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
         return None
 
     parent = scope.expression.parent
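As a quick, hedged illustration of the new guard (table and column names invented, output not asserted): eliminate_subqueries should now leave a correlated LATERAL subquery inline instead of lifting it into a CTE.

import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# A correlated LATERAL subquery; with the new check it is expected to stay
# inline rather than be rewritten into a CTE.
sql = (
    "SELECT o.id, l.total FROM orders AS o CROSS JOIN LATERAL "
    "(SELECT SUM(i.amount) AS total FROM items AS i WHERE i.order_id = o.id) AS l"
)
print(eliminate_subqueries(sqlglot.parse_one(sql)).sql())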
@@ -1,8 +1,23 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 
 
+@t.overload
 def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+    ...
+
+
+@t.overload
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+    ...
+
+
+def normalize_identifiers(expression, dialect=None):
     """
     Normalize all unquoted identifiers to either lower or upper case, depending
     on the dialect. This essentially makes those identifiers case-insensitive.
@@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
         >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
         >>> normalize_identifiers(expression).sql()
         'SELECT bar.a AS a FROM "Foo".bar'
+        >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
+        'FOO'
 
     Args:
         expression: The expression to transform.
@@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
     Returns:
         The transformed expression.
     """
+    expression = exp.maybe_parse(expression, dialect=dialect)
     return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
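Together these hunks widen normalize_identifiers to accept either a parsed expression or a plain string, which is run through exp.maybe_parse first. A small usage sketch mirroring the doctest above:

import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

# Existing behaviour: normalize a parsed expression.
expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
print(normalize_identifiers(expression).sql())
# 'SELECT bar.a AS a FROM "Foo".bar'

# New behaviour: a bare string is parsed first, then normalized per dialect.
print(normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake"))
# 'FOO'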
@@ -39,6 +39,7 @@ def qualify_columns(
     """
     schema = ensure_schema(schema)
     infer_schema = schema.empty if infer_schema is None else infer_schema
+    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 
     for scope in traverse_scope(expression):
         resolver = Resolver(scope, schema, infer_schema=infer_schema)
@@ -55,7 +56,7 @@ def qualify_columns(
             _expand_alias_refs(scope, resolver)
 
         if not isinstance(scope.expression, exp.UDTF):
-            _expand_stars(scope, resolver, using_column_tables)
+            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
             _qualify_outputs(scope)
         _expand_group_by(scope)
         _expand_order_by(scope, resolver)
@@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
 
 
 def _expand_stars(
-    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+    scope: Scope,
+    resolver: Resolver,
+    using_column_tables: t.Dict[str, t.Any],
+    pseudocolumns: t.Set[str],
 ) -> None:
     """Expand stars to lists of column selections"""
 
@@ -367,14 +371,8 @@ def _expand_stars(
         columns = resolver.get_source_columns(table, only_visible=True)
 
-        # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
-        # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
-        if resolver.schema.dialect == "bigquery":
-            columns = [
-                name
-                for name in columns
-                if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
-            ]
+        if pseudocolumns:
+            columns = [name for name in columns if name.upper() not in pseudocolumns]
 
         if columns and "*" not in columns:
             if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
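Star expansion now consults a dialect-level PSEUDOCOLUMNS set instead of hard-coding BigQuery. A standalone sketch of the filter, assuming BigQuery declares _PARTITIONTIME and _PARTITIONDATE as pseudocolumns (the column list below is illustrative, not from a real schema):

from sqlglot.dialects.dialect import Dialect

# For BigQuery this set is expected to contain _PARTITIONTIME and _PARTITIONDATE;
# other dialects can declare their own pseudocolumns.
pseudocolumns = Dialect.get_or_raise("bigquery").PSEUDOCOLUMNS

columns = ["a", "b", "_PARTITIONTIME", "_PARTITIONDATE"]
visible = [name for name in columns if name.upper() not in pseudocolumns]
print(visible)  # expected: ['a', 'b']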
@@ -80,7 +80,9 @@ def qualify_tables(
                         header = next(reader)
                         columns = next(reader)
                         schema.add_table(
-                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
+                            source,
+                            {k: type(v).__name__ for k, v in zip(header, columns)},
+                            match_depth=False,
                         )
             elif isinstance(source, Scope) and source.is_udtf:
                 udtf = source.expression
@@ -435,7 +435,10 @@ class Scope:
     @property
     def is_correlated_subquery(self):
         """Determine if this scope is a correlated subquery"""
-        return bool(self.is_subquery and self.external_columns)
+        return bool(
+            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+            and self.external_columns
+        )
 
     def rename_source(self, old_name, new_name):
         """Rename a source in this scope"""
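With the Lateral case added, a subquery nested in a LATERAL that references outer columns should now report itself as correlated. A minimal, schema-less sketch with invented names:

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

sql = (
    "SELECT c.id, r.last_order FROM customers AS c CROSS JOIN LATERAL "
    "(SELECT MAX(o.ordered_at) AS last_order FROM orders AS o WHERE o.customer_id = c.id) AS r"
)

for scope in traverse_scope(sqlglot.parse_one(sql)):
    # The inner SELECT references c.id from the outer query, so its scope is
    # expected to report is_correlated_subquery as True.
    print(type(scope.expression).__name__, scope.is_correlated_subquery)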
@@ -486,7 +489,7 @@ class Scope:
 
 def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     """
-    Traverse an expression by it's "scopes".
+    Traverse an expression by its "scopes".
 
     "Scope" represents the current context of a Select statement.
 
@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     Returns:
         list[Scope]: scope instances
     """
-    if not isinstance(expression, exp.Unionable):
-        return []
-    return list(_traverse_scope(Scope(expression)))
+    if isinstance(expression, exp.Unionable) or (
+        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+    ):
+        return list(_traverse_scope(Scope(expression)))
+
+    return []
 
 
 def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
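traverse_scope previously returned an empty list for anything that was not Unionable; it now also descends into DDL whose inner expression is subqueryable, such as CREATE TABLE ... AS SELECT. A minimal sketch (table names invented):

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

expression = sqlglot.parse_one("CREATE TABLE t AS SELECT a, b FROM src WHERE b > 0")

# Scopes are now yielded for the SELECT inside the DDL statement.
for scope in traverse_scope(expression):
    print(scope.expression.sql())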
@@ -539,7 +545,9 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.Table):
         yield from _traverse_tables(scope)
     elif isinstance(scope.expression, exp.UDTF):
-        pass
+        yield from _traverse_udtfs(scope)
+    elif isinstance(scope.expression, exp.DDL):
+        yield from _traverse_ddl(scope)
     else:
         logger.warning(
             "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
     for cte in scope.ctes:
         recursive_scope = None
 
-        # if the scope is a recursive cte, it must be in the form of
-        # base_case UNION recursive. thus the recursive scope is the first
-        # section of the union.
-        if scope.expression.args["with"].recursive:
+        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+        # thus the recursive scope is the first section of the union.
+        with_ = scope.expression.args.get("with")
+        if with_ and with_.recursive:
             union = cte.this
 
             if isinstance(union, exp.Union):
@@ -692,8 +700,7 @@ def _traverse_tables(scope):
             # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
             # Until then, this means that only a single, unaliased derived table is allowed (rather,
             # the latest one wins.
-            alias = expression.alias
-            sources[alias] = child_scope
+            sources[expression.alias] = child_scope
 
             # append the final child_scope yielded
             scopes.append(child_scope)
@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
         scope.subquery_scopes.append(top)
 
 
+def _traverse_udtfs(scope):
+    if isinstance(scope.expression, exp.Unnest):
+        expressions = scope.expression.expressions
+    elif isinstance(scope.expression, exp.Lateral):
+        expressions = [scope.expression.this]
+    else:
+        expressions = []
+
+    sources = {}
+    for expression in expressions:
+        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+            top = None
+            for child_scope in _traverse_scope(
+                scope.branch(
+                    expression,
+                    scope_type=ScopeType.DERIVED_TABLE,
+                    outer_column_list=expression.alias_column_names,
+                )
+            ):
+                yield child_scope
+                top = child_scope
+                sources[expression.alias] = child_scope
+
+            scope.derived_table_scopes.append(top)
+            scope.table_scopes.append(top)
+
+    scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+    yield from _traverse_ctes(scope)
+
+    query_scope = scope.branch(
+        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+    )
+    query_scope._collect()
+    query_scope._ctes = scope.ctes + query_scope._ctes
+
+    yield from _traverse_scope(query_scope)
+
+
 def walk_in_scope(expression, bfs=True):
     """
     Returns a generator object which visits all nodes in the syntax tree, stopping at
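A hedged sketch of the new UDTF traversal: the derived table inside a LATERAL now gets its own scope instead of being skipped (names invented; the exact set of yielded scopes may vary by dialect and query shape):

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

sql = (
    "SELECT e.name, s.total FROM employees AS e CROSS JOIN LATERAL "
    "(SELECT SUM(p.amount) AS total FROM payments AS p WHERE p.employee_id = e.id) AS s"
)

# The subquery inside the LATERAL is expected to show up as its own Select scope.
for scope in traverse_scope(sqlglot.parse_one(sql)):
    print(type(scope.expression).__name__)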
@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
     if not predicate or parent_select is not predicate.parent_select:
         return
 
-    # this subquery returns a scalar and can just be converted to a cross join
+    # This subquery returns a scalar and can just be converted to a cross join
     if not isinstance(predicate, (exp.In, exp.Any)):
-        having = predicate.find_ancestor(exp.Having)
         column = exp.column(select.selects[0].alias_or_name, alias)
-        if having and having.parent_select is parent_select:
-            column = exp.Max(this=column)
-        _replace(select.parent, column)
 
-        parent_select.join(
-            select,
-            join_type="CROSS",
-            join_alias=alias,
-            copy=False,
-        )
+        clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+        clause_parent_select = clause.parent_select if clause else None
+
+        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+            (not clause or clause_parent_select is not parent_select)
+            and (
+                parent_select.args.get("group")
+                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+            )
+        ):
+            column = exp.Max(this=column)
+
+        _replace(select.parent, column)
+        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
         return
 
     if select.find(exp.Limit, exp.Offset):
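A minimal sketch of running the rule after this change (invented tables, output not asserted): the scalar subquery is converted to a cross join, and with the new logic it is wrapped in MAX only when the surrounding clause or an aggregating parent select requires it.

import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# Uncorrelated scalar subquery compared against a column in WHERE; the rule
# converts it into a CROSS JOIN against the subquery.
sql = "SELECT x.a FROM x WHERE x.b = (SELECT y.b FROM y WHERE y.c = 1)"
print(unnest_subqueries(sqlglot.parse_one(sql)).sql())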