Merging upstream version 12.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
fffa0d5761
commit
62b2b24d3b
100 changed files with 35022 additions and 30936 deletions
|
@ -153,7 +153,7 @@ def join_condition(join):
|
|||
#
|
||||
# should pull y.b as the join key and x.a as the source key
|
||||
if normalized(on):
|
||||
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
|
||||
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)
|
||||
|
||||
for condition in on.flatten():
|
||||
if isinstance(condition, exp.EQ):
|
||||
|
|
|
@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
|
|||
for column in projection.find_all(exp.Column):
|
||||
if not column.table and column.name in alias_to_expression:
|
||||
column.replace(alias_to_expression[column.name].copy())
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = projection.this
|
||||
return expression
|
||||
|
|
|
@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
|
|||
lambda c: to_func(
|
||||
uniq_sort(flatten(from_func(c, b.left)), cache),
|
||||
uniq_sort(flatten(from_func(c, b.right)), cache),
|
||||
copy=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
a = to_func(
|
||||
uniq_sort(flatten(from_func(a, b.left)), cache),
|
||||
uniq_sort(flatten(from_func(a, b.right)), cache),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
return a
|
||||
|
|
|
@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
|
|||
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
|
||||
from sqlglot.optimizer.eliminate_joins import eliminate_joins
|
||||
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
|
||||
from sqlglot.optimizer.expand_laterals import expand_laterals
|
||||
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
|
||||
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
|
||||
from sqlglot.optimizer.lower_identities import lower_identities
|
||||
|
@ -30,7 +29,6 @@ RULES = (
|
|||
qualify_tables,
|
||||
isolate_table_selects,
|
||||
qualify_columns,
|
||||
expand_laterals,
|
||||
pushdown_projections,
|
||||
validate_qualify_columns,
|
||||
normalize,
|
||||
|
|
|
@ -3,11 +3,12 @@ import typing as t
|
|||
|
||||
from sqlglot import alias, exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
|
||||
from sqlglot.optimizer.scope import Scope, traverse_scope
|
||||
from sqlglot.schema import ensure_schema
|
||||
|
||||
|
||||
def qualify_columns(expression, schema):
|
||||
def qualify_columns(expression, schema, expand_laterals=True):
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified columns.
|
||||
|
||||
|
@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
|
|||
"""
|
||||
schema = ensure_schema(schema)
|
||||
|
||||
if not schema.mapping and expand_laterals:
|
||||
expression = _expand_laterals(expression)
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
resolver = Resolver(scope, schema)
|
||||
_pop_table_column_aliases(scope.ctes)
|
||||
|
@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
|
|||
_expand_group_by(scope, resolver)
|
||||
_expand_order_by(scope)
|
||||
|
||||
if schema.mapping and expand_laterals:
|
||||
expression = _expand_laterals(expression)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
|
|||
tables[join_table] = None
|
||||
|
||||
join.args.pop("using")
|
||||
join.set("on", exp.and_(*conditions))
|
||||
join.set("on", exp.and_(*conditions, copy=False))
|
||||
|
||||
if column_tables:
|
||||
for column in scope.columns:
|
||||
|
@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
|
|||
# column_table can be a '' because bigquery unnest has no table alias
|
||||
if column_table:
|
||||
column.set("table", column_table)
|
||||
elif column_table not in scope.sources:
|
||||
elif column_table not in scope.sources and (
|
||||
not scope.parent or column_table not in scope.parent.sources
|
||||
):
|
||||
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
|
||||
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
|
||||
|
||||
|
@ -376,10 +385,13 @@ def _qualify_outputs(scope):
|
|||
if not selection.output_name:
|
||||
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
|
||||
elif not isinstance(selection, exp.Alias) and not selection.is_star:
|
||||
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
|
||||
alias_.set("this", selection)
|
||||
selection = alias_
|
||||
|
||||
selection = alias(
|
||||
selection,
|
||||
alias=selection.output_name or f"_col_{i}",
|
||||
quoted=True
|
||||
if isinstance(selection, exp.Column) and selection.this.quoted
|
||||
else None,
|
||||
)
|
||||
if aliased_column:
|
||||
selection.set("alias", exp.to_identifier(aliased_column))
|
||||
|
||||
|
|
|
@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
|
|||
|
||||
def qualify_tables(expression, db=None, catalog=None, schema=None):
|
||||
"""
|
||||
Rewrite sqlglot AST to have fully qualified tables.
|
||||
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
|
||||
replaces "join constructs" (*) by equivalent SELECT * subqueries.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> import sqlglot
|
||||
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
|
||||
>>> qualify_tables(expression, db="db").sql()
|
||||
'SELECT 1 FROM db.tbl AS tbl'
|
||||
>>>
|
||||
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
|
||||
>>> qualify_tables(expression).sql()
|
||||
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
|
||||
|
||||
Args:
|
||||
expression (sqlglot.Expression): expression to qualify
|
||||
db (str): Database name
|
||||
catalog (str): Catalog name
|
||||
schema: A schema to populate
|
||||
|
||||
Returns:
|
||||
sqlglot.Expression: qualified expression
|
||||
|
||||
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
|
||||
"""
|
||||
sequence = itertools.count()
|
||||
|
||||
|
@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
|
|||
|
||||
for scope in traverse_scope(expression):
|
||||
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
|
||||
# Expand join construct
|
||||
if isinstance(derived_table, exp.Subquery):
|
||||
unnested = derived_table.unnest()
|
||||
if isinstance(unnested, exp.Table):
|
||||
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
|
||||
|
||||
if not derived_table.args.get("alias"):
|
||||
alias_ = f"_q_{next(sequence)}"
|
||||
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
|
||||
|
|
|
@ -510,6 +510,9 @@ def _traverse_scope(scope):
|
|||
yield from _traverse_union(scope)
|
||||
elif isinstance(scope.expression, exp.Subquery):
|
||||
yield from _traverse_subqueries(scope)
|
||||
elif isinstance(scope.expression, exp.Table):
|
||||
# This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
|
||||
yield from _traverse_tables(scope)
|
||||
elif isinstance(scope.expression, exp.UDTF):
|
||||
pass
|
||||
else:
|
||||
|
@ -587,6 +590,9 @@ def _traverse_tables(scope):
|
|||
for join in scope.expression.args.get("joins") or []:
|
||||
expressions.append(join.this)
|
||||
|
||||
if isinstance(scope.expression, exp.Table):
|
||||
expressions.append(scope.expression)
|
||||
|
||||
expressions.extend(scope.expression.args.get("laterals") or [])
|
||||
|
||||
for expression in expressions:
|
||||
|
|
|
@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
|
|||
return exp.and_(
|
||||
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
|
||||
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
|
||||
copy=False,
|
||||
)
|
||||
return expression
|
||||
|
||||
|
@ -76,9 +77,17 @@ def simplify_not(expression):
|
|||
if isinstance(expression.this, exp.Paren):
|
||||
condition = expression.this.unnest()
|
||||
if isinstance(condition, exp.And):
|
||||
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
return exp.or_(
|
||||
exp.not_(condition.left, copy=False),
|
||||
exp.not_(condition.right, copy=False),
|
||||
copy=False,
|
||||
)
|
||||
if isinstance(condition, exp.Or):
|
||||
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
|
||||
return exp.and_(
|
||||
exp.not_(condition.left, copy=False),
|
||||
exp.not_(condition.right, copy=False),
|
||||
copy=False,
|
||||
)
|
||||
if is_null(condition):
|
||||
return exp.null()
|
||||
if always_true(expression.this):
|
||||
|
@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
|
|||
# A AND C AND B -> A AND B AND C
|
||||
for i, (sql, e) in enumerate(arr[1:]):
|
||||
if sql < arr[i][0]:
|
||||
expression = result_func(*(e for _, e in sorted(arr)))
|
||||
expression = result_func(*(e for _, e in sorted(arr)), copy=False)
|
||||
break
|
||||
else:
|
||||
# we didn't have to sort but maybe we need to dedup
|
||||
if len(deduped) < len(flattened):
|
||||
expression = result_func(*deduped.values())
|
||||
expression = result_func(*deduped.values(), copy=False)
|
||||
|
||||
return expression
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue