Merging upstream version 11.4.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
ecb42ec17f
commit
63746a3e92
89 changed files with 35352 additions and 33081 deletions
|
@ -1,9 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.helper import should_identify
|
||||
|
||||
|
||||
def canonicalize(expression: exp.Expression) -> exp.Expression:
|
||||
def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
|
||||
"""Converts a sql expression into a standard form.
|
||||
|
||||
This method relies on annotate_types because many of the
|
||||
|
@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
|
|||
|
||||
Args:
|
||||
expression: The expression to canonicalize.
|
||||
identify: Whether or not to force identify identifier.
|
||||
"""
|
||||
exp.replace_children(expression, canonicalize)
|
||||
exp.replace_children(expression, canonicalize, identify=identify)
|
||||
|
||||
expression = add_text_to_concat(expression)
|
||||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
expression = ensure_bool_predicates(expression)
|
||||
|
||||
if isinstance(expression, exp.Identifier):
|
||||
expression.set("quoted", True)
|
||||
if should_identify(expression.this, identify):
|
||||
expression.set("quoted", True)
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expression, exp.Connector):
|
||||
_replace_int_predicate(expression.left)
|
||||
_replace_int_predicate(expression.right)
|
||||
|
||||
elif isinstance(expression, (exp.Where, exp.Having)):
|
||||
_replace_int_predicate(expression.this)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
|
||||
for a, b in itertools.permutations([a, b]):
|
||||
if (
|
||||
|
@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
|
|||
cast = exp.Cast(this=node.copy(), to=data_type)
|
||||
cast.type = data_type
|
||||
node.replace(cast)
|
||||
|
||||
|
||||
def _replace_int_predicate(expression: exp.Expression) -> None:
|
||||
if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
|
||||
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.helper import flatten
|
||||
from sqlglot.optimizer.qualify_columns import Resolver
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
else:
|
||||
order_refs = set()
|
||||
|
||||
new_selections = defaultdict(list)
|
||||
new_selections = []
|
||||
removed = False
|
||||
star = False
|
||||
|
||||
for selection in scope.selects:
|
||||
name = selection.alias_or_name
|
||||
|
||||
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
|
||||
new_selections[name].append(selection)
|
||||
new_selections.append(selection)
|
||||
else:
|
||||
if selection.is_star:
|
||||
star = True
|
||||
|
@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema):
|
|||
|
||||
if star:
|
||||
resolver = Resolver(scope, schema)
|
||||
names = {s.alias_or_name for s in new_selections}
|
||||
|
||||
for name in sorted(parent_selections):
|
||||
if name not in new_selections:
|
||||
new_selections[name].append(
|
||||
alias(exp.column(name, table=resolver.get_table(name)), name)
|
||||
)
|
||||
if name not in names:
|
||||
new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
|
||||
|
||||
# If there are no remaining selections, just select a single constant
|
||||
if not new_selections:
|
||||
new_selections[""].append(DEFAULT_SELECTION())
|
||||
new_selections.append(DEFAULT_SELECTION())
|
||||
|
||||
scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
|
||||
scope.expression.select(*new_selections, append=False, copy=False)
|
||||
|
||||
if removed:
|
||||
scope.clear_cache()
|
||||
|
|
|
@ -37,6 +37,7 @@ def qualify_columns(expression, schema):
|
|||
_qualify_outputs(scope)
|
||||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver):
|
|||
# column_table can be a '' because bigquery unnest has no table alias
|
||||
if column_table:
|
||||
column.set("table", column_table)
|
||||
elif column_table not in scope.sources:
|
||||
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
|
||||
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
|
||||
|
||||
root, *parts = column.parts
|
||||
|
||||
if root.name in scope.sources:
|
||||
# struct is already qualified, but we still need to change the AST representation
|
||||
column_table = root
|
||||
root, *parts = parts
|
||||
else:
|
||||
column_table = resolver.get_table(root.name)
|
||||
|
||||
if column_table:
|
||||
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
|
||||
|
||||
columns_missing_from_scope = []
|
||||
# Determine whether each reference in the order by clause is to a column or an alias.
|
||||
|
@ -373,10 +389,14 @@ class Resolver:
|
|||
if isinstance(node, exp.Subqueryable):
|
||||
while node and node.alias != table_name:
|
||||
node = node.parent
|
||||
|
||||
node_alias = node.args.get("alias")
|
||||
if node_alias:
|
||||
return node_alias.this
|
||||
return exp.to_identifier(table_name)
|
||||
|
||||
return exp.to_identifier(
|
||||
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
|
||||
)
|
||||
|
||||
@property
|
||||
def all_columns(self):
|
||||
|
|
|
@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
scope.rename_source(None, alias_)
|
||||
|
||||
for source in scope.sources.values():
|
||||
for name, source in scope.sources.items():
|
||||
if isinstance(source, exp.Table):
|
||||
identifier = isinstance(source.this, exp.Identifier)
|
||||
|
||||
if identifier:
|
||||
if isinstance(source.this, exp.Identifier):
|
||||
if not source.args.get("db"):
|
||||
source.set("db", exp.to_identifier(db))
|
||||
if not source.args.get("catalog"):
|
||||
|
@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
source = source.replace(
|
||||
alias(
|
||||
source.copy(),
|
||||
source.this if identifier else next_name(),
|
||||
name if name else next_name(),
|
||||
table=True,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -4,6 +4,7 @@ from enum import Enum, auto
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import find_new_name
|
||||
|
||||
|
||||
class ScopeType(Enum):
|
||||
|
@ -293,6 +294,8 @@ class Scope:
|
|||
result = {}
|
||||
|
||||
for name, node in referenced_names:
|
||||
if name in result:
|
||||
raise OptimizeError(f"Alias already used: {name}")
|
||||
if name in self.sources:
|
||||
result[name] = (node, self.sources[name])
|
||||
|
||||
|
@ -594,6 +597,8 @@ def _traverse_tables(scope):
|
|||
if table_name in scope.sources:
|
||||
# This is a reference to a parent source (e.g. a CTE), not an actual table.
|
||||
sources[source_name] = scope.sources[table_name]
|
||||
elif source_name in sources:
|
||||
sources[find_new_name(sources, table_name)] = expression
|
||||
else:
|
||||
sources[source_name] = expression
|
||||
continue
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue