1
0
Fork 0

Merging upstream version 26.3.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 22:00:29 +01:00
parent 34733e7b48
commit c16ed2270a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
89 changed files with 59179 additions and 57645 deletions

View file

@ -56,7 +56,21 @@ def qualify_columns(
dialect = Dialect.get_or_raise(schema.dialect)
pseudocolumns = dialect.PSEUDOCOLUMNS
snowflake_or_oracle = dialect in ("oracle", "snowflake")
for scope in traverse_scope(expression):
scope_expression = scope.expression
is_select = isinstance(scope_expression, exp.Select)
if is_select and snowflake_or_oracle and scope_expression.args.get("connect"):
# In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL
# pseudocolumn, which doesn't belong to a table, so we change it into an identifier
scope_expression.transform(
lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n,
copy=False,
)
scope.clear_cache()
resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
@ -76,7 +90,7 @@ def qualify_columns(
if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver, dialect)
if isinstance(scope.expression, exp.Select):
if is_select:
if expand_stars:
_expand_stars(
scope,
@ -159,6 +173,9 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
if names and not ordered:
raise OptimizeError(f"Joins {names} missing source table {scope.expression}")
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
@ -180,6 +197,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
is_semi_or_anti_join = join.is_semi_or_anti_join
for identifier in using:
identifier = identifier.name
@ -208,10 +226,14 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
if table not in tables:
tables[table] = None
if join_table not in tables:
tables[join_table] = None
# Do not update the dict if this was a SEMI/ANTI join in
# order to avoid generating COALESCE columns for this join pair
if not is_semi_or_anti_join:
if table not in tables:
tables[table] = None
if join_table not in tables:
tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions, copy=False))
@ -898,22 +920,10 @@ class Resolver:
for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
pseudocolumns = self._get_source_pseudocolumns(name)
if pseudocolumns:
columns = list(columns)
columns.extend(c for c in pseudocolumns if c not in columns)
self._get_source_columns_cache[cache_key] = columns
return self._get_source_columns_cache[cache_key]
def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
# When there is a CONNECT BY clause, there is only one table being scanned
# See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
return ["LEVEL"]
return []
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:
self._source_columns = {

View file

@ -100,6 +100,7 @@ class Scope:
self._join_hints = None
self._pivots = None
self._references = None
self._semi_anti_join_tables = None
def branch(
self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
@ -126,6 +127,7 @@ class Scope:
self._raw_columns = []
self._stars = []
self._join_hints = []
self._semi_anti_join_tables = set()
for node in self.walk(bfs=False):
if node is self.expression:
@ -139,6 +141,10 @@ class Scope:
else:
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
parent = node.parent
if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
self._semi_anti_join_tables.add(node.alias_or_name)
self._tables.append(node)
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
@ -311,6 +317,11 @@ class Scope:
result = {}
for name, node in self.references:
if name in self._semi_anti_join_tables:
# The RHS table of SEMI/ANTI joins shouldn't be collected as a
# selected source
continue
if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
@ -351,7 +362,10 @@ class Scope:
self._external_columns = left.external_columns + right.external_columns
else:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
c
for c in self.columns
if c.table not in self.selected_sources
and c.table not in self.semi_or_anti_join_tables
]
return self._external_columns
@ -387,6 +401,10 @@ class Scope:
return self._pivots
@property
def semi_or_anti_join_tables(self):
return self._semi_anti_join_tables or set()
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.

View file

@ -749,7 +749,7 @@ def simplify_parens(expression):
if (
not isinstance(this, exp.Select)
and not isinstance(parent, exp.SubqueryPredicate)
and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)