1
0
Fork 0

Merging upstream version 11.4.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:46:19 +01:00
parent ecb42ec17f
commit 63746a3e92
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
89 changed files with 35352 additions and 33081 deletions

View file

@ -1,9 +1,12 @@
from __future__ import annotations
import itertools
from sqlglot import exp
from sqlglot.helper import should_identify
def canonicalize(expression: exp.Expression) -> exp.Expression:
def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
identify: Whether or not to force identify identifier.
"""
exp.replace_children(expression, canonicalize)
exp.replace_children(expression, canonicalize, identify=identify)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
if should_identify(expression.this, identify):
expression.set("quoted", True)
return expression
@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Connector):
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
elif isinstance(expression, (exp.Where, exp.Having)):
_replace_int_predicate(expression.this)
return expression
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
def _replace_int_predicate(expression: exp.Expression) -> None:
if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))

View file

@ -1,7 +1,6 @@
from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.helper import flatten
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema):
else:
order_refs = set()
new_selections = defaultdict(list)
new_selections = []
removed = False
star = False
for selection in scope.selects:
name = selection.alias_or_name
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
new_selections[name].append(selection)
new_selections.append(selection)
else:
if selection.is_star:
star = True
@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema):
if star:
resolver = Resolver(scope, schema)
names = {s.alias_or_name for s in new_selections}
for name in sorted(parent_selections):
if name not in new_selections:
new_selections[name].append(
alias(exp.column(name, table=resolver.get_table(name)), name)
)
if name not in names:
new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections[""].append(DEFAULT_SELECTION())
new_selections.append(DEFAULT_SELECTION())
scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
scope.expression.select(*new_selections, append=False, copy=False)
if removed:
scope.clear_cache()

View file

@ -37,6 +37,7 @@ def qualify_columns(expression, schema):
_qualify_outputs(scope)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
return expression
@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources:
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
root, *parts = column.parts
if root.name in scope.sources:
# struct is already qualified, but we still need to change the AST representation
column_table = root
root, *parts = parts
else:
column_table = resolver.get_table(root.name)
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
@ -373,10 +389,14 @@ class Resolver:
if isinstance(node, exp.Subqueryable):
while node and node.alias != table_name:
node = node.parent
node_alias = node.args.get("alias")
if node_alias:
return node_alias.this
return exp.to_identifier(table_name)
return exp.to_identifier(
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
)
@property
def all_columns(self):

View file

@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
for source in scope.sources.values():
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
identifier = isinstance(source.this, exp.Identifier)
if identifier:
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
source.this if identifier else next_name(),
name if name else next_name(),
table=True,
)
)

View file

@ -4,6 +4,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
class ScopeType(Enum):
@ -293,6 +294,8 @@ class Scope:
result = {}
for name, node in referenced_names:
if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
result[name] = (node, self.sources[name])
@ -594,6 +597,8 @@ def _traverse_tables(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
sources[source_name] = expression
continue