
Merging upstream version 26.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:59:10 +01:00
parent e2fd836612
commit 63d24513e5
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
65 changed files with 45416 additions and 44542 deletions

sqlglot/optimizer/qualify_columns.py

@@ -66,6 +66,7 @@ def qualify_columns(
             _expand_alias_refs(
                 scope,
                 resolver,
+                dialect,
                 expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
             )
@@ -73,9 +74,9 @@ def qualify_columns(
         _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

         if not schema.empty and expand_alias_refs:
-            _expand_alias_refs(scope, resolver)
+            _expand_alias_refs(scope, resolver, dialect)

-        if not isinstance(scope.expression, exp.UDTF):
+        if isinstance(scope.expression, exp.Select):
             if expand_stars:
                 _expand_stars(
                     scope,
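A minimal sketch (not part of the diff) of the alias expansion this threaded-through dialect now controls, using sqlglot's public qualify() entry point; the table y and its schema are made up for the example:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    # bar is an alias defined in the same SELECT list; qualification
    # rewrites the later reference back to y.foo.
    expr = sqlglot.parse_one("SELECT y.foo AS bar, bar * 2 AS baz FROM y")
    print(qualify(expr, schema={"y": {"foo": "int"}}).sql())
    # e.g. SELECT "y"."foo" AS "bar", "y"."foo" * 2 AS "baz" FROM "y" AS "y"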
@@ -236,7 +237,15 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     return column_tables


-def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
+def _expand_alias_refs(
+    scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
+) -> None:
+    """
+    Expand references to aliases.
+    Example:
+        SELECT y.foo AS bar, bar * 2 AS baz FROM y
+        => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
+    """
     expression = scope.expression

     if not isinstance(expression, exp.Select):
@@ -309,6 +318,12 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
     replace_columns(expression.args.get("having"), resolve_table=True)
     replace_columns(expression.args.get("qualify"), resolve_table=True)

+    # Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
+    # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
+    if dialect == "snowflake":
+        for join in expression.args.get("joins") or []:
+            replace_columns(join)
+
     scope.clear_cache()
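A hedged usage sketch (not from the commit) of the behavior this enables: with the Snowflake dialect, an alias defined in the SELECT list can now also be expanded inside JOIN ... ON. The table names and schema are hypothetical:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    sql = "SELECT t1.a AS x FROM t1 JOIN t2 ON x = t2.a"
    expr = qualify(
        sqlglot.parse_one(sql, read="snowflake"),
        schema={"t1": {"a": "int"}, "t2": {"a": "int"}},
        dialect="snowflake",
    )
    # The ON condition's x should now resolve to t1.a instead of failing
    # to qualify.
    print(expr.sql(dialect="snowflake"))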
@@ -883,10 +898,22 @@ class Resolver:
                 for (name, alias) in itertools.zip_longest(columns, column_aliases)
             ]

+            pseudocolumns = self._get_source_pseudocolumns(name)
+            if pseudocolumns:
+                columns = list(columns)
+                columns.extend(c for c in pseudocolumns if c not in columns)
+
             self._get_source_columns_cache[cache_key] = columns

         return self._get_source_columns_cache[cache_key]

+    def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
+        if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
+            # When there is a CONNECT BY clause, there is only one table being scanned
+            # See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
+            return ["LEVEL"]
+        return []
+
     def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
         if self._source_columns is None:
             self._source_columns = {
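A sketch of what the new pseudocolumn hook enables, under the same assumptions (the employees table is hypothetical): qualifying a Snowflake CONNECT BY query should no longer trip over the LEVEL pseudocolumn, since the resolver now reports it as a column of the single scanned source:

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    sql = """
        SELECT employee_id, LEVEL
        FROM employees
        START WITH manager_id IS NULL
        CONNECT BY PRIOR employee_id = manager_id
    """
    expr = qualify(
        sqlglot.parse_one(sql, read="snowflake"),
        schema={"employees": {"employee_id": "int", "manager_id": "int"}},
        dialect="snowflake",
    )
    print(expr.sql(dialect="snowflake"))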

sqlglot/optimizer/simplify.py

@@ -1328,13 +1328,17 @@ def _flat_simplify(expression, simplifier, root=True):
     return expression


-def gen(expression: t.Any) -> str:
+def gen(expression: t.Any, comments: bool = False) -> str:
     """Simple pseudo sql generator for quickly generating sortable and uniq strings.

     Sorting and deduping sql is a necessary step for optimization. Calling the actual
     generator is expensive so we have a bare minimum sql generator here.
+
+    Args:
+        expression: the expression to convert into a SQL string.
+        comments: whether to include the expression's comments.
     """
-    return Gen().gen(expression)
+    return Gen().gen(expression, comments=comments)


 class Gen:
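A small sketch (not in the diff) of what the new flag changes: since gen() output is used for sorting and deduping, passing comments=True means two expressions that differ only in their comments should no longer collapse to the same key string:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import gen

    a = parse_one("SELECT x /* keep */ FROM t")
    b = parse_one("SELECT x FROM t")

    print(gen(a) == gen(b))                                # expected True: comments ignored
    print(gen(a, comments=True) == gen(b, comments=True))  # expected False: comments included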
@@ -1342,7 +1346,7 @@ class Gen:
         self.stack = []
         self.sqls = []

-    def gen(self, expression: exp.Expression) -> str:
+    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
         self.stack = [expression]
         self.sqls.clear()
@@ -1350,6 +1354,9 @@ class Gen:
             node = self.stack.pop()

             if isinstance(node, exp.Expression):
+                if comments and node.comments:
+                    self.stack.append(f" /*{','.join(node.comments)}*/")
+
                 exp_handler_name = f"{node.key}_sql"

                 if hasattr(self, exp_handler_name):
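Why push the comment string before dispatching to the node's handler? The stack is LIFO, so pieces the handler pushes afterwards are popped, and therefore emitted, before the comment, which makes the comment trail the node's own SQL. A toy illustration (not sqlglot's actual code):

    # Toy model of the Gen traversal: strings are emitted, expressions expanded.
    stack = ["expr"]
    sqls = []
    while stack:
        node = stack.pop()
        if node == "expr":
            stack.append(" /*c*/")    # pushed first -> popped last -> emitted last
            stack.append("SELECT a")  # handler output, popped and emitted first
        else:
            sqls.append(node)
    print("".join(sqls))  # SELECT a /*c*/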