Merging upstream version 16.7.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 331a760a3d
commit 088f137198
75 changed files with 33866 additions and 31988 deletions
@@ -60,8 +60,8 @@ def qualify(
         The qualified expression.
     """
     schema = ensure_schema(schema, dialect=dialect)
+    expression = normalize_identifiers(expression, dialect=dialect)
     expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
-    expression = normalize_identifiers(expression, dialect=dialect)

     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)
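
For context only (not part of the diff): a minimal sketch of the qualify() entry point this hunk reorders, using an invented table t and a toy schema. The exact quoting of the output depends on the dialect.

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Illustrative only: the table "t" and its schema are assumptions, not from the commit.
expression = sqlglot.parse_one("SELECT a, b FROM t")
qualified = qualify(expression, schema={"t": {"a": "INT", "b": "INT"}})

# Identifiers are normalized, tables are qualified, and outputs are aliased,
# e.g. SELECT "t"."a" AS "a", "t"."b" AS "b" FROM "t" AS "t"
print(qualified.sql())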
@@ -56,13 +56,13 @@ def qualify_columns(
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver, using_column_tables)
             _qualify_outputs(scope)
-        _expand_group_by(scope, resolver)
-        _expand_order_by(scope)
+        _expand_group_by(scope)
+        _expand_order_by(scope, resolver)

     return expression


-def validate_qualify_columns(expression):
+def validate_qualify_columns(expression: E) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     unqualified_columns = []
     for scope in traverse_scope(expression):
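
As a hedged illustration of the validation path this hunk annotates (my own example, with hypothetical tables x and y): an ambiguous column survives qualification unqualified, and validate_qualify_columns is then expected to raise.

import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

# Invented schema: the column "a" exists in both sources, so it cannot be resolved.
schema = {"x": {"a": "INT"}, "y": {"a": "INT"}}
expression = sqlglot.parse_one("SELECT a FROM x CROSS JOIN y")

try:
    validate_qualify_columns(qualify_columns(expression, schema))
except OptimizeError as error:
    print(error)  # expected to report the ambiguous column "a"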
@@ -79,7 +79,7 @@ def validate_qualify_columns(expression)
     return expression


-def _pop_table_column_aliases(derived_tables):
+def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
     """
     Remove table column aliases.

@@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables):
             table_alias.args.pop("columns", None)


-def _expand_using(scope, resolver):
+def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     joins = list(scope.find_all(exp.Join))
     names = {join.alias_or_name for join in joins}
     ordered = [key for key in scope.selected_sources if key not in names]

     # Mapping of automatically joined column names to an ordered set of source names (dict).
-    column_tables = {}
+    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

     for join in joins:
         using = join.args.get("using")
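
A rough end-to-end sketch of what the USING bookkeeping feeds into (the query and schema are mine, and the exact output shape may vary between versions):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Hypothetical tables joined on a shared "id" column.
schema = {"x": {"id": "INT", "a": "INT"}, "y": {"id": "INT", "b": "INT"}}
expression = sqlglot.parse_one("SELECT id, a, b FROM x JOIN y USING (id)")

# The USING join is rewritten to an ON condition, and the bare "id" is resolved through
# the column-to-source mapping built here (typically surfacing as a COALESCE of both sides).
print(qualify_columns(expression, schema).sql())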
@@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:

     alias_to_expression: t.Dict[str, exp.Expression] = {}

-    def replace_columns(
-        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
-    ):
+    def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
         if not node:
             return

         for column, *_ in walk_in_scope(node):
             if not isinstance(column, exp.Column):
                 continue
-            table = resolver.get_table(column.name) if resolve_agg and not column.table else None
-            if table and column.find_ancestor(exp.AggFunc):
+            table = resolver.get_table(column.name) if resolve_table and not column.table else None
+            alias_expr = alias_to_expression.get(column.name)
+            double_agg = (
+                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
+                if alias_expr
+                else False
+            )
+
+            if table and (not alias_expr or double_agg):
                 column.set("table", table)
-            elif expand and not column.table and column.name in alias_to_expression:
-                column.replace(alias_to_expression[column.name].copy())
+            elif not column.table and alias_expr and not double_agg:
                 column.replace(alias_expr.copy())

     for projection in scope.selects:
         replace_columns(projection)
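
A small, hedged illustration of what replace_columns now does (the query and schema are my own; the output shown is approximate): alias references in HAVING are expanded back to the aliased expression, while unqualified columns inside aggregates are resolved to their source table.

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Illustrative only; the table "t" is an assumption.
sql = "SELECT x + 1 AS y FROM t GROUP BY x + 1 HAVING y > 1 AND SUM(z) > 0"
out = qualify(sqlglot.parse_one(sql), schema={"t": {"x": "INT", "z": "INT"}})

# Roughly: HAVING "t"."x" + 1 > 1 AND SUM("t"."z") > 0 -- the alias y is expanded,
# and z inside the aggregate is qualified via the resolver.
print(out.sql())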
@@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:

     replace_columns(expression.args.get("where"))
     replace_columns(expression.args.get("group"))
-    replace_columns(expression.args.get("having"), resolve_agg=True)
-    replace_columns(expression.args.get("qualify"), resolve_agg=True)
-    replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+    replace_columns(expression.args.get("having"), resolve_table=True)
+    replace_columns(expression.args.get("qualify"), resolve_table=True)
     scope.clear_cache()


-def _expand_group_by(scope, resolver):
-    group = scope.expression.args.get("group")
+def _expand_group_by(scope: Scope):
+    expression = scope.expression
+    group = expression.args.get("group")
     if not group:
         return

     group.set("expressions", _expand_positional_references(scope, group.expressions))
-    scope.expression.set("group", group)
+    expression.set("group", group)
+
+    # group by expressions cannot be simplified, for example
+    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+    # the projection must exactly match the group by key
+    groups = set(group.expressions)
+    group.meta["final"] = True
+
+    for e in expression.selects:
+        for node, *_ in e.walk():
+            if node in groups:
+                e.meta["final"] = True
+                break
+
+    having = expression.args.get("having")
+    if having:
+        for node, *_ in having.walk():
+            if node in groups:
+                having.meta["final"] = True
+                break


-def _expand_order_by(scope):
+def _expand_order_by(scope: Scope, resolver: Resolver):
     order = scope.expression.args.get("order")
     if not order:
         return
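
Why the new meta["final"] tagging matters, as a sketch (the table and schema are hypothetical; this uses the public optimize() pipeline, which runs both qualification and simplification):

import sqlglot
from sqlglot.optimizer import optimize

# The GROUP BY key and its matching projection must stay structurally identical,
# so the tagged "x + 1 + 1" should be left unfolded rather than simplified to "x + 2"
# on one side only.
expression = sqlglot.parse_one("SELECT x + 1 + 1 AS a FROM y GROUP BY x + 1 + 1")
print(optimize(expression, schema={"y": {"x": "INT"}}).sql())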
@@ -220,10 +244,21 @@ def _expand_order_by(scope):
         ordereds,
         _expand_positional_references(scope, (o.this for o in ordereds)),
     ):
+        for agg in ordered.find_all(exp.AggFunc):
+            for col in agg.find_all(exp.Column):
+                if not col.table:
+                    col.set("table", resolver.get_table(col.name))
+
         ordered.set("this", new_expression)

+    if scope.expression.args.get("group"):
+        selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
+
+        for ordered in ordereds:
+            ordered.set("this", selects.get(ordered.this, ordered.this))
+

-def _expand_positional_references(scope, expressions):
+def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
     new_nodes = []
     for node in expressions:
         if node.is_int:
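
A hedged sketch of the two ORDER BY refinements above (query and schema invented): unqualified columns inside ordering aggregates are resolved, and when a GROUP BY is present an ordering expression that matches a projection may be swapped for that projection's output name.

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Illustrative only; the comments describe the intent of the hunk, not a verbatim output.
sql = "SELECT x AS a, SUM(y) AS s FROM t GROUP BY x ORDER BY 1, SUM(y)"
out = qualify(sqlglot.parse_one(sql), schema={"t": {"x": "INT", "y": "INT"}})

# ORDER BY 1 is expanded from its position, y inside SUM() gets its table attached,
# and ordering keys matching a projection can be referenced by the output alias.
print(out.sql())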
@@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions):
     return new_nodes


-def _qualify_columns(scope, resolver):
+def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
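
For orientation, a small example of the disambiguation step itself (hypothetical schema):

import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns

# Each bare column has exactly one matching source here, so it gets that table attached.
schema = {"x": {"a": "INT"}, "y": {"b": "INT"}}
expression = sqlglot.parse_one("SELECT a, b FROM x JOIN y ON x.a = y.b")

print(qualify_columns(expression, schema).sql())
# e.g. SELECT x.a AS a, y.b AS b FROM x JOIN y ON x.a = y.b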
@@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver):
             column.set("table", column_table)


-def _expand_stars(scope, resolver, using_column_tables):
+def _expand_stars(
+    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+) -> None:
     """Expand stars to lists of column selections"""

     new_selections = []
-    except_columns = {}
-    replace_columns = {}
+    except_columns: t.Dict[int, t.Set[str]] = {}
+    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
     coalesced_columns = set()

     # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
     pivot_columns = None
     pivot_output_columns = None
-    pivot = seq_get(scope.pivots, 0)
+    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

     has_pivoted_source = pivot and not pivot.args.get("unpivot")
-    if has_pivoted_source:
+    if pivot and has_pivoted_source:
         pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

         pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
@@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables):

             columns = resolver.get_source_columns(table, only_visible=True)

+            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
+            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
+            if resolver.schema.dialect == "bigquery":
+                columns = [
+                    name
+                    for name in columns
+                    if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
+                ]
+
             if columns and "*" not in columns:
-                if has_pivoted_source:
+                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                     implicit_columns = [col for col in columns if col not in pivot_columns]
                     new_selections.extend(
                         exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
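
A sketch of the intended BigQuery behavior (table and schema are made up): even when a schema advertises the ingestion-time pseudo-columns, SELECT * should not expand to them.

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Illustrative only; assumes the schema is registered under the bigquery dialect.
schema = {"t": {"a": "INT", "_PARTITIONTIME": "TIMESTAMP"}}
expression = sqlglot.parse_one("SELECT * FROM t", read="bigquery")
out = qualify(expression, schema=schema, dialect="bigquery")

# The star is expected to expand to "a" only, leaving _PARTITIONTIME out.
print(out.sql("bigquery"))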
@@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables):
     scope.expression.set("expressions", new_selections)


-def _add_except_columns(expression, tables, except_columns):
+def _add_except_columns(
+    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
+) -> None:
     except_ = expression.args.get("except")

     if not except_:
@@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns):
         except_columns[id(table)] = columns


-def _add_replace_columns(expression, tables, replace_columns):
+def _add_replace_columns(
+    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
+) -> None:
     replace = expression.args.get("replace")

     if not replace:
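
These two helpers record per-table EXCEPT and REPLACE modifiers for star expansion; a hedged example of the end result (my own query and schema):

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Illustrative only: BigQuery-style star modifiers on an invented table.
schema = {"t": {"a": "INT", "b": "INT", "c": "INT"}}
expression = sqlglot.parse_one("SELECT * EXCEPT (b) REPLACE (a + 1 AS a) FROM t", read="bigquery")
out = qualify(expression, schema=schema, dialect="bigquery")

# Roughly: b is dropped, a is replaced by a + 1, and c passes through unchanged.
print(out.sql("bigquery"))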
@@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns):
         replace_columns[id(table)] = columns


-def _qualify_outputs(scope):
+def _qualify_outputs(scope: Scope):
     """Ensure all output columns are aliased"""
     new_selections = []

@@ -429,7 +479,7 @@ class Resolver:
     This is a class so we can lazily load some things and easily share them across functions.
     """

-    def __init__(self, scope, schema, infer_schema: bool = True):
+    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
         self.scope = scope
         self.schema = schema
         self._source_columns = None
@@ -28,6 +28,8 @@ def simplify(expression):
     generate = cached_generator()

     def _simplify(expression, root=True):
+        if expression.meta.get("final"):
+            return expression
         node = expression
         node = rewrite_between(node)
         node = uniq_sort(node, generate, root)
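
A minimal sketch of the new early return (example is mine): a node tagged with meta["final"] is handed back untouched, which is how the GROUP BY keys tagged in qualify_columns survive simplification.

import sqlglot
from sqlglot.optimizer.simplify import simplify

# Illustrative only; without the tag, the simplifier would fold x + 1 + 1 into x + 2.
expression = sqlglot.parse_one("SELECT x + 1 + 1 FROM t")
expression.selects[0].meta["final"] = True  # protect the projection

print(simplify(expression).sql())  # the tagged expression is expected to stay as x + 1 + 1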