1
0
Fork 0

Merging upstream version 26.14.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-04-16 09:04:43 +02:00
parent 68f1150572
commit e9f53ab285
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
84 changed files with 63872 additions and 61909 deletions

View file

@@ -23,7 +23,7 @@ def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp
def _canonicalize(expression: exp.Expression) -> exp.Expression:
expression = add_text_to_concat(expression)
expression = replace_date_funcs(expression)
expression = replace_date_funcs(expression, dialect=dialect)
expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
expression = remove_redundant_casts(expression)
expression = ensure_bools(expression, _replace_int_predicate)
@@ -39,7 +39,7 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
return node
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
if (
isinstance(node, (exp.Date, exp.TsOrDsToDate))
and not node.expressions
@@ -52,7 +52,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
if not node.type:
from sqlglot.optimizer.annotate_types import annotate_types
node = annotate_types(node)
node = annotate_types(node, dialect=dialect)
return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
return node

View file

@@ -237,12 +237,12 @@ def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) ->
source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name)
if isinstance(source, exp.Subquery):
source.set("alias", exp.TableAlias(this=new_alias))
elif isinstance(source, exp.Table) and source.alias:
if isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source, new_alias))
elif isinstance(source.parent, exp.Subquery):
source.parent.set("alias", exp.TableAlias(this=new_alias))
for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))

View file

@@ -49,25 +49,31 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.SetOperation):
left, right = scope.union_scopes
if len(left.expression.selects) != len(right.expression.selects):
scope_sql = scope.expression.sql()
raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
set_op = scope.expression
if not (set_op.kind or set_op.side):
# Do not optimize this set operation if it's using the BigQuery specific
# kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
left, right = scope.union_scopes
if len(left.expression.selects) != len(right.expression.selects):
scope_sql = scope.expression.sql()
raise OptimizeError(
f"Invalid set operation due to column mismatch: {scope_sql}."
)
referenced_columns[left] = parent_selections
referenced_columns[left] = parent_selections
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
if scope.expression.args.get("by_name"):
referenced_columns[right] = referenced_columns[left]
else:
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections
or select.alias_or_name in parent_selections
]
if any(select.is_star for select in right.expression.selects):
referenced_columns[right] = parent_selections
elif not any(select.is_star for select in left.expression.selects):
if scope.expression.args.get("by_name"):
referenced_columns[right] = referenced_columns[left]
else:
referenced_columns[right] = [
right.expression.selects[i].alias_or_name
for i, select in enumerate(left.expression.selects)
if SELECT_ALL in parent_selections
or select.alias_or_name in parent_selections
]
if isinstance(scope.expression, exp.Select):
if remove_unused_selections:

View file

@@ -140,13 +140,14 @@ def validate_qualify_columns(expression: E) -> E:
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
name_column = []
field = unpivot.args.get("field")
if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
name_column.append(field.this)
name_columns = [
field.this
for field in unpivot.fields
if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
]
value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
return itertools.chain(name_column, value_columns)
return itertools.chain(name_columns, value_columns)
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
@@ -608,18 +609,19 @@ def _expand_stars(
dialect = resolver.schema.dialect
pivot_output_columns = None
pivot_exclude_columns = None
pivot_exclude_columns: t.Set[str] = set()
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
if pivot.unpivot:
pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
field = pivot.args.get("field")
if isinstance(field, exp.In):
pivot_exclude_columns = {
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
}
for field in pivot.fields:
if isinstance(field, exp.In):
pivot_exclude_columns.update(
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
)
else:
pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
@@ -916,6 +918,32 @@ class Resolver:
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
set_op = source.expression
# BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
on_column_list = set_op.args.get("on")
if on_column_list:
# The resulting columns are the columns in the ON clause:
# {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
columns = [col.name for col in on_column_list]
elif set_op.side or set_op.kind:
side = set_op.side
kind = set_op.kind
left = set_op.left.named_selects
right = set_op.right.named_selects
# We use dict.fromkeys to deduplicate keys and maintain insertion order
if side == "LEFT":
columns = left
elif side == "FULL":
columns = list(dict.fromkeys(left + right))
elif kind == "INNER":
columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
else:
columns = set_op.named_selects
else:
columns = source.expression.named_selects

View file

@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
from sqlglot.dialects.dialect import Dialect
if t.TYPE_CHECKING:
from sqlglot._typing import E
@@ -49,6 +50,7 @@ def qualify_tables(
next_alias_name = name_sequence("_q_")
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
dialect = Dialect.get_or_raise(dialect)
def _qualify(table: exp.Table) -> None:
if isinstance(table.this, exp.Identifier):
@@ -127,8 +129,8 @@ def qualify_tables(
if not table_alias.name:
table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
column_aliases = dialect.generate_values_aliases(udtf)
table_alias.set("columns", column_aliases)
else:
for node in scope.walk():
if (