Merging upstream version 17.11.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
2bd548fc43
commit
14ca349bca
69 changed files with 30974 additions and 30030 deletions
|
@ -41,5 +41,6 @@ def normalize_identifiers(expression, dialect=None):
|
|||
Returns:
|
||||
The transformed expression.
|
||||
"""
|
||||
expression = exp.maybe_parse(expression, dialect=dialect)
|
||||
if isinstance(expression, str):
|
||||
expression = exp.to_identifier(expression)
|
||||
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
|
||||
|
|
|
@ -31,6 +31,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
"""
|
||||
# Map of Scope to all columns being selected by outer queries.
|
||||
schema = ensure_schema(schema)
|
||||
source_column_alias_count = {}
|
||||
referenced_columns = defaultdict(set)
|
||||
|
||||
# We build the scope tree (which is traversed in DFS postorder), then iterate
|
||||
|
@ -38,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
# columns for a particular scope are completely build by the time we get to it.
|
||||
for scope in reversed(traverse_scope(expression)):
|
||||
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
|
||||
alias_count = source_column_alias_count.get(scope, 0)
|
||||
|
||||
if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
|
||||
if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
|
||||
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
|
||||
# we select from a pivoted source in the parent scope.
|
||||
parent_selections = {SELECT_ALL}
|
||||
|
@ -59,7 +61,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if remove_unused_selections:
|
||||
_remove_unused_selections(scope, parent_selections, schema)
|
||||
_remove_unused_selections(scope, parent_selections, schema, alias_count)
|
||||
|
||||
if scope.expression.is_star:
|
||||
continue
|
||||
|
@ -72,15 +74,19 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
selects[table_name].add(col_name)
|
||||
|
||||
# Push the selected columns down to the next scope
|
||||
for name, (_, source) in scope.selected_sources.items():
|
||||
for name, (node, source) in scope.selected_sources.items():
|
||||
if isinstance(source, Scope):
|
||||
columns = selects.get(name) or set()
|
||||
referenced_columns[source].update(columns)
|
||||
|
||||
column_aliases = node.alias_column_names
|
||||
if column_aliases:
|
||||
source_column_alias_count[source] = len(column_aliases)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _remove_unused_selections(scope, parent_selections, schema):
|
||||
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
|
||||
order = scope.expression.args.get("order")
|
||||
|
||||
if order:
|
||||
|
@ -93,11 +99,14 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
removed = False
|
||||
star = False
|
||||
|
||||
select_all = SELECT_ALL in parent_selections
|
||||
|
||||
for selection in scope.expression.selects:
|
||||
name = selection.alias_or_name
|
||||
|
||||
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
|
||||
if select_all or name in parent_selections or name in order_refs or alias_count > 0:
|
||||
new_selections.append(selection)
|
||||
alias_count -= 1
|
||||
else:
|
||||
if selection.is_star:
|
||||
star = True
|
||||
|
|
|
@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import Dialect, DialectType
|
|||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify_parens
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
||||
|
||||
|
@ -58,6 +59,7 @@ def qualify_columns(
|
|||
if not isinstance(scope.expression, exp.UDTF):
|
||||
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
|
||||
_qualify_outputs(scope)
|
||||
|
||||
_expand_group_by(scope)
|
||||
_expand_order_by(scope, resolver)
|
||||
|
||||
|
@ -85,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
|
|||
"""
|
||||
Remove table column aliases.
|
||||
|
||||
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
|
||||
"""
|
||||
for derived_table in derived_tables:
|
||||
table_alias = derived_table.args.get("alias")
|
||||
|
@ -111,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
|
||||
columns = {}
|
||||
|
||||
for k in scope.selected_sources:
|
||||
if k in ordered:
|
||||
for column in resolver.get_source_columns(k):
|
||||
if column not in columns:
|
||||
columns[column] = k
|
||||
for source_name in scope.selected_sources:
|
||||
if source_name in ordered:
|
||||
for column_name in resolver.get_source_columns(source_name):
|
||||
if column_name not in columns:
|
||||
columns[column_name] = source_name
|
||||
|
||||
source_table = ordered[-1]
|
||||
ordered.append(join_table)
|
||||
|
@ -183,6 +185,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
for column, *_ in walk_in_scope(node):
|
||||
if not isinstance(column, exp.Column):
|
||||
continue
|
||||
|
||||
table = resolver.get_table(column.name) if resolve_table and not column.table else None
|
||||
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
|
||||
double_agg = (
|
||||
|
@ -198,7 +201,10 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
if literal_index:
|
||||
column.replace(exp.Literal.number(i))
|
||||
else:
|
||||
column.replace(alias_expr.copy())
|
||||
column = column.replace(exp.paren(alias_expr))
|
||||
simplified = simplify_parens(column)
|
||||
if simplified is not column:
|
||||
column.replace(simplified)
|
||||
|
||||
for i, projection in enumerate(scope.expression.selects):
|
||||
replace_columns(projection)
|
||||
|
@ -213,7 +219,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
scope.clear_cache()
|
||||
|
||||
|
||||
def _expand_group_by(scope: Scope):
|
||||
def _expand_group_by(scope: Scope) -> None:
|
||||
expression = scope.expression
|
||||
group = expression.args.get("group")
|
||||
if not group:
|
||||
|
@ -223,7 +229,7 @@ def _expand_group_by(scope: Scope):
|
|||
expression.set("group", group)
|
||||
|
||||
|
||||
def _expand_order_by(scope: Scope, resolver: Resolver):
|
||||
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
|
||||
order = scope.expression.args.get("order")
|
||||
if not order:
|
||||
return
|
||||
|
@ -442,7 +448,7 @@ def _add_replace_columns(
|
|||
replace_columns[id(table)] = columns
|
||||
|
||||
|
||||
def _qualify_outputs(scope: Scope):
|
||||
def _qualify_outputs(scope: Scope) -> None:
|
||||
"""Ensure all output columns are aliased"""
|
||||
new_selections = []
|
||||
|
||||
|
@ -482,9 +488,9 @@ class Resolver:
|
|||
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
|
||||
self.scope = scope
|
||||
self.schema = schema
|
||||
self._source_columns = None
|
||||
self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
|
||||
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
|
||||
self._all_columns = None
|
||||
self._all_columns: t.Optional[t.Set[str]] = None
|
||||
self._infer_schema = infer_schema
|
||||
|
||||
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
|
||||
|
@ -528,7 +534,7 @@ class Resolver:
|
|||
return exp.to_identifier(table_name)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
def all_columns(self) -> t.Set[str]:
|
||||
"""All available columns of all sources in this scope"""
|
||||
if self._all_columns is None:
|
||||
self._all_columns = {
|
||||
|
@ -536,53 +542,67 @@ class Resolver:
|
|||
}
|
||||
return self._all_columns
|
||||
|
||||
def get_source_columns(self, name, only_visible=False):
|
||||
"""Resolve the source columns for a given source `name`"""
|
||||
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
|
||||
"""Resolve the source columns for a given source `name`."""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
||||
source = self.scope.sources[name]
|
||||
|
||||
# If referencing a table, return the columns from the schema
|
||||
if isinstance(source, exp.Table):
|
||||
return self.schema.column_names(source, only_visible)
|
||||
columns = self.schema.column_names(source, only_visible)
|
||||
elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
|
||||
columns = source.expression.alias_column_names
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
||||
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
|
||||
return source.expression.alias_column_names
|
||||
node, _ = self.scope.selected_sources.get(name) or (None, None)
|
||||
if isinstance(node, Scope):
|
||||
column_aliases = node.expression.alias_column_names
|
||||
elif isinstance(node, exp.Expression):
|
||||
column_aliases = node.alias_column_names
|
||||
else:
|
||||
column_aliases = []
|
||||
|
||||
# Otherwise, if referencing another scope, return that scope's named selects
|
||||
return source.expression.named_selects
|
||||
# If the source's columns are aliased, their aliases shadow the corresponding column names
|
||||
return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]
|
||||
|
||||
def _get_all_source_columns(self):
|
||||
def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
k: self.get_source_columns(k)
|
||||
for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
|
||||
source_name: self.get_source_columns(source_name)
|
||||
for source_name, source in itertools.chain(
|
||||
self.scope.selected_sources.items(), self.scope.lateral_sources.items()
|
||||
)
|
||||
}
|
||||
return self._source_columns
|
||||
|
||||
def _get_unambiguous_columns(self, source_columns):
|
||||
def _get_unambiguous_columns(
|
||||
self, source_columns: t.Dict[str, t.List[str]]
|
||||
) -> t.Dict[str, str]:
|
||||
"""
|
||||
Find all the unambiguous columns in sources.
|
||||
|
||||
Args:
|
||||
source_columns (dict): Mapping of names to source columns
|
||||
source_columns: Mapping of names to source columns.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of column name to source name
|
||||
Mapping of column name to source name.
|
||||
"""
|
||||
if not source_columns:
|
||||
return {}
|
||||
|
||||
source_columns = list(source_columns.items())
|
||||
source_columns_pairs = list(source_columns.items())
|
||||
|
||||
first_table, first_columns = source_columns[0]
|
||||
first_table, first_columns = source_columns_pairs[0]
|
||||
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
|
||||
all_columns = set(unambiguous_columns)
|
||||
|
||||
for table, columns in source_columns[1:]:
|
||||
for table, columns in source_columns_pairs[1:]:
|
||||
unique = self._find_unique_columns(columns)
|
||||
ambiguous = set(all_columns).intersection(unique)
|
||||
all_columns.update(columns)
|
||||
|
||||
for column in ambiguous:
|
||||
unambiguous_columns.pop(column, None)
|
||||
for column in unique.difference(ambiguous):
|
||||
|
@ -591,7 +611,7 @@ class Resolver:
|
|||
return unambiguous_columns
|
||||
|
||||
@staticmethod
|
||||
def _find_unique_columns(columns):
|
||||
def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
|
||||
"""
|
||||
Find the unique columns in a list of columns.
|
||||
|
||||
|
@ -601,7 +621,7 @@ class Resolver:
|
|||
|
||||
This is necessary because duplicate column names are ambiguous.
|
||||
"""
|
||||
counts = {}
|
||||
counts: t.Dict[str, int] = {}
|
||||
for column in columns:
|
||||
counts[column] = counts.get(column, 0) + 1
|
||||
return {column for column, count in counts.items() if count == 1}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue