Merging upstream version 26.14.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
68f1150572
commit
e9f53ab285
84 changed files with 63872 additions and 61909 deletions
|
@ -23,7 +23,7 @@ def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp
|
|||
|
||||
def _canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = replace_date_funcs(expression)
|
||||
expression = replace_date_funcs(expression, dialect=dialect)
|
||||
expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bools(expression, _replace_int_predicate)
|
||||
|
@ -39,7 +39,7 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
|
|||
return node
|
||||
|
||||
|
||||
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
||||
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
|
||||
if (
|
||||
isinstance(node, (exp.Date, exp.TsOrDsToDate))
|
||||
and not node.expressions
|
||||
|
@ -52,7 +52,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
|||
if not node.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
node = annotate_types(node)
|
||||
node = annotate_types(node, dialect=dialect)
|
||||
return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
|
||||
|
||||
return node
|
||||
|
|
|
@ -237,12 +237,12 @@ def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) ->
|
|||
source, _ = inner_scope.selected_sources[conflict]
|
||||
new_alias = exp.to_identifier(new_name)
|
||||
|
||||
if isinstance(source, exp.Subquery):
|
||||
source.set("alias", exp.TableAlias(this=new_alias))
|
||||
elif isinstance(source, exp.Table) and source.alias:
|
||||
if isinstance(source, exp.Table) and source.alias:
|
||||
source.set("alias", new_alias)
|
||||
elif isinstance(source, exp.Table):
|
||||
source.replace(exp.alias_(source, new_alias))
|
||||
elif isinstance(source.parent, exp.Subquery):
|
||||
source.parent.set("alias", exp.TableAlias(this=new_alias))
|
||||
|
||||
for column in inner_scope.source_columns(conflict):
|
||||
column.set("table", exp.to_identifier(new_name))
|
||||
|
|
|
@ -49,25 +49,31 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.SetOperation):
|
||||
left, right = scope.union_scopes
|
||||
if len(left.expression.selects) != len(right.expression.selects):
|
||||
scope_sql = scope.expression.sql()
|
||||
raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
|
||||
set_op = scope.expression
|
||||
if not (set_op.kind or set_op.side):
|
||||
# Do not optimize this set operation if it's using the BigQuery specific
|
||||
# kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
|
||||
left, right = scope.union_scopes
|
||||
if len(left.expression.selects) != len(right.expression.selects):
|
||||
scope_sql = scope.expression.sql()
|
||||
raise OptimizeError(
|
||||
f"Invalid set operation due to column mismatch: {scope_sql}."
|
||||
)
|
||||
|
||||
referenced_columns[left] = parent_selections
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
if scope.expression.args.get("by_name"):
|
||||
referenced_columns[right] = referenced_columns[left]
|
||||
else:
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections
|
||||
or select.alias_or_name in parent_selections
|
||||
]
|
||||
if any(select.is_star for select in right.expression.selects):
|
||||
referenced_columns[right] = parent_selections
|
||||
elif not any(select.is_star for select in left.expression.selects):
|
||||
if scope.expression.args.get("by_name"):
|
||||
referenced_columns[right] = referenced_columns[left]
|
||||
else:
|
||||
referenced_columns[right] = [
|
||||
right.expression.selects[i].alias_or_name
|
||||
for i, select in enumerate(left.expression.selects)
|
||||
if SELECT_ALL in parent_selections
|
||||
or select.alias_or_name in parent_selections
|
||||
]
|
||||
|
||||
if isinstance(scope.expression, exp.Select):
|
||||
if remove_unused_selections:
|
||||
|
|
|
@ -140,13 +140,14 @@ def validate_qualify_columns(expression: E) -> E:
|
|||
|
||||
|
||||
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
|
||||
name_column = []
|
||||
field = unpivot.args.get("field")
|
||||
if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
|
||||
name_column.append(field.this)
|
||||
|
||||
name_columns = [
|
||||
field.this
|
||||
for field in unpivot.fields
|
||||
if isinstance(field, exp.In) and isinstance(field.this, exp.Column)
|
||||
]
|
||||
value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
|
||||
return itertools.chain(name_column, value_columns)
|
||||
|
||||
return itertools.chain(name_columns, value_columns)
|
||||
|
||||
|
||||
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
|
||||
|
@ -608,18 +609,19 @@ def _expand_stars(
|
|||
dialect = resolver.schema.dialect
|
||||
|
||||
pivot_output_columns = None
|
||||
pivot_exclude_columns = None
|
||||
pivot_exclude_columns: t.Set[str] = set()
|
||||
|
||||
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
|
||||
if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
|
||||
if pivot.unpivot:
|
||||
pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
|
||||
|
||||
field = pivot.args.get("field")
|
||||
if isinstance(field, exp.In):
|
||||
pivot_exclude_columns = {
|
||||
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
|
||||
}
|
||||
for field in pivot.fields:
|
||||
if isinstance(field, exp.In):
|
||||
pivot_exclude_columns.update(
|
||||
c.output_name for e in field.expressions for c in e.find_all(exp.Column)
|
||||
)
|
||||
|
||||
else:
|
||||
pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
|
||||
|
||||
|
@ -916,6 +918,32 @@ class Resolver:
|
|||
if source.expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in source.expression.type.expressions: # type: ignore
|
||||
columns.append(k.name)
|
||||
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
|
||||
set_op = source.expression
|
||||
|
||||
# BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME
|
||||
on_column_list = set_op.args.get("on")
|
||||
|
||||
if on_column_list:
|
||||
# The resulting columns are the columns in the ON clause:
|
||||
# {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...)
|
||||
columns = [col.name for col in on_column_list]
|
||||
elif set_op.side or set_op.kind:
|
||||
side = set_op.side
|
||||
kind = set_op.kind
|
||||
|
||||
left = set_op.left.named_selects
|
||||
right = set_op.right.named_selects
|
||||
|
||||
# We use dict.fromkeys to deduplicate keys and maintain insertion order
|
||||
if side == "LEFT":
|
||||
columns = left
|
||||
elif side == "FULL":
|
||||
columns = list(dict.fromkeys(left + right))
|
||||
elif kind == "INNER":
|
||||
columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys())
|
||||
else:
|
||||
columns = set_op.named_selects
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import DialectType
|
|||
from sqlglot.helper import csv_reader, name_sequence
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import Schema
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
@ -49,6 +50,7 @@ def qualify_tables(
|
|||
next_alias_name = name_sequence("_q_")
|
||||
db = exp.parse_identifier(db, dialect=dialect) if db else None
|
||||
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
def _qualify(table: exp.Table) -> None:
|
||||
if isinstance(table.this, exp.Identifier):
|
||||
|
@ -127,8 +129,8 @@ def qualify_tables(
|
|||
if not table_alias.name:
|
||||
table_alias.set("this", exp.to_identifier(next_alias_name()))
|
||||
if isinstance(udtf, exp.Values) and not table_alias.columns:
|
||||
for i, e in enumerate(udtf.expressions[0].expressions):
|
||||
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
|
||||
column_aliases = dialect.generate_values_aliases(udtf)
|
||||
table_alias.set("columns", column_aliases)
|
||||
else:
|
||||
for node in scope.walk():
|
||||
if (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue