Merging upstream version 26.3.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
34733e7b48
commit
c16ed2270a
89 changed files with 59179 additions and 57645 deletions
|
@ -56,7 +56,21 @@ def qualify_columns(
|
|||
dialect = Dialect.get_or_raise(schema.dialect)
|
||||
pseudocolumns = dialect.PSEUDOCOLUMNS
|
||||
|
||||
snowflake_or_oracle = dialect in ("oracle", "snowflake")
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
scope_expression = scope.expression
|
||||
is_select = isinstance(scope_expression, exp.Select)
|
||||
|
||||
if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
|
||||
# In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
|
||||
# pseudocolumn, which doesn't belong to a table, so we change it into an identifier
|
||||
scope_expression.transform(
|
||||
lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
|
||||
copy=False,
|
||||
)
|
||||
scope.clear_cache()
|
||||
|
||||
resolver = Resolver(scope, schema, infer_schema=infer_schema)
|
||||
_pop_table_column_aliases(scope.ctes)
|
||||
_pop_table_column_aliases(scope.derived_tables)
|
||||
|
@ -76,7 +90,7 @@ def qualify_columns(
|
|||
if not schema.empty and expand_alias_refs:
|
||||
_expand_alias_refs(scope, resolver, dialect)
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if is_select:
|
||||
if expand_stars:
|
||||
_expand_stars(
|
||||
scope,
|
||||
|
@ -159,6 +173,9 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
names = {join.alias_or_name for join in joins}
|
||||
ordered = [key for key in scope.selected_sources if key not in names]
|
||||
|
||||
if names and not ordered:
|
||||
raise OptimizeError(f"Joins {names} missing source table {scope.expression}")
|
||||
|
||||
# Mapping of automatically joined column names to an ordered set of source names (dict).
|
||||
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
|
||||
|
||||
|
@ -180,6 +197,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
join_columns = resolver.get_source_columns(join_table)
|
||||
conditions = []
|
||||
using_identifier_count = len(using)
|
||||
is_semi_or_anti_join = join.is_semi_or_anti_join
|
||||
|
||||
for identifier in using:
|
||||
identifier = identifier.name
|
||||
|
@ -208,10 +226,14 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
|
||||
# Set all values in the dict to None, because we only care about the key ordering
|
||||
tables = column_tables.setdefault(identifier, {})
|
||||
if table not in tables:
|
||||
tables[table] = None
|
||||
if join_table not in tables:
|
||||
tables[join_table] = None
|
||||
|
||||
# Do not update the dict if this was a SEMI/ANTI join in
|
||||
# order to avoid generating COALESCE columns for this join pair
|
||||
if not is_semi_or_anti_join:
|
||||
if table not in tables:
|
||||
tables[table] = None
|
||||
if join_table not in tables:
|
||||
tables[join_table] = None
|
||||
|
||||
join.args.pop("using")
|
||||
join.set("on", exp.and_(*conditions, copy=False))
|
||||
|
@ -898,22 +920,10 @@ class Resolver:
|
|||
for (name, alias) in itertools.zip_longest(columns, column_aliases)
|
||||
]
|
||||
|
||||
pseudocolumns = self._get_source_pseudocolumns(name)
|
||||
if pseudocolumns:
|
||||
columns = list(columns)
|
||||
columns.extend(c for c in pseudocolumns if c not in columns)
|
||||
|
||||
self._get_source_columns_cache[cache_key] = columns
|
||||
|
||||
return self._get_source_columns_cache[cache_key]
|
||||
|
||||
def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
|
||||
if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
|
||||
# When there is a CONNECT BY clause, there is only one table being scanned
|
||||
# See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
|
||||
return ["LEVEL"]
|
||||
return []
|
||||
|
||||
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
|
||||
if self._source_columns is None:
|
||||
self._source_columns = {
|
||||
|
|
|
@ -100,6 +100,7 @@ class Scope:
|
|||
self._join_hints = None
|
||||
self._pivots = None
|
||||
self._references = None
|
||||
self._semi_anti_join_tables = None
|
||||
|
||||
def branch(
|
||||
self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
|
||||
|
@ -126,6 +127,7 @@ class Scope:
|
|||
self._raw_columns = []
|
||||
self._stars = []
|
||||
self._join_hints = []
|
||||
self._semi_anti_join_tables = set()
|
||||
|
||||
for node in self.walk(bfs=False):
|
||||
if node is self.expression:
|
||||
|
@ -139,6 +141,10 @@ class Scope:
|
|||
else:
|
||||
self._raw_columns.append(node)
|
||||
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
|
||||
parent = node.parent
|
||||
if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
|
||||
self._semi_anti_join_tables.add(node.alias_or_name)
|
||||
|
||||
self._tables.append(node)
|
||||
elif isinstance(node, exp.JoinHint):
|
||||
self._join_hints.append(node)
|
||||
|
@ -311,6 +317,11 @@ class Scope:
|
|||
result = {}
|
||||
|
||||
for name, node in self.references:
|
||||
if name in self._semi_anti_join_tables:
|
||||
# The RHS table of SEMI/ANTI joins shouldn't be collected as a
|
||||
# selected source
|
||||
continue
|
||||
|
||||
if name in result:
|
||||
raise OptimizeError(f"Alias already used: {name}")
|
||||
if name in self.sources:
|
||||
|
@ -351,7 +362,10 @@ class Scope:
|
|||
self._external_columns = left.external_columns + right.external_columns
|
||||
else:
|
||||
self._external_columns = [
|
||||
c for c in self.columns if c.table not in self.selected_sources
|
||||
c
|
||||
for c in self.columns
|
||||
if c.table not in self.selected_sources
|
||||
and c.table not in self.semi_or_anti_join_tables
|
||||
]
|
||||
|
||||
return self._external_columns
|
||||
|
@ -387,6 +401,10 @@ class Scope:
|
|||
|
||||
return self._pivots
|
||||
|
||||
@property
|
||||
def semi_or_anti_join_tables(self):
|
||||
return self._semi_anti_join_tables or set()
|
||||
|
||||
def source_columns(self, source_name):
|
||||
"""
|
||||
Get all columns in the current scope for a particular source.
|
||||
|
|
|
@ -749,7 +749,7 @@ def simplify_parens(expression):
|
|||
|
||||
if (
|
||||
not isinstance(this, exp.Select)
|
||||
and not isinstance(parent, exp.SubqueryPredicate)
|
||||
and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
|
||||
and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue