1
0
Fork 0

Merging upstream version 12.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:53:39 +01:00
parent fffa0d5761
commit 62b2b24d3b
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
100 changed files with 35022 additions and 30936 deletions

View file

@ -153,7 +153,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)
for condition in on.flatten():
if isinstance(condition, exp.EQ):

View file

@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
return expression

View file

@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
copy=False,
)
return a

View file

@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@ -30,7 +29,6 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
expand_laterals,
pushdown_projections,
validate_qualify_columns,
normalize,

View file

@ -3,11 +3,12 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
def qualify_columns(expression, schema):
def qualify_columns(expression, schema, expand_laterals=True):
"""
Rewrite sqlglot AST to have fully qualified columns.
@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
"""
schema = ensure_schema(schema)
if not schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
if schema.mapping and expand_laterals:
expression = _expand_laterals(expression)
return expression
@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
join.set("on", exp.and_(*conditions, copy=False))
if column_tables:
for column in scope.columns:
@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources:
elif column_table not in scope.sources and (
not scope.parent or column_table not in scope.parent.sources
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@ -376,10 +385,13 @@ def _qualify_outputs(scope):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias) and not selection.is_star:
alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
alias_.set("this", selection)
selection = alias_
selection = alias(
selection,
alias=selection.output_name or f"_col_{i}",
quoted=True
if isinstance(selection, exp.Column) and selection.this.quoted
else None,
)
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))

View file

@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
Example:
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()
@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
# Expand join construct
if isinstance(derived_table, exp.Subquery):
unnested = derived_table.unnest()
if isinstance(unnested, exp.Table):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))

View file

@ -510,6 +510,9 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
# This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@ -587,6 +590,9 @@ def _traverse_tables(scope):
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
if isinstance(scope.expression, exp.Table):
expressions.append(scope.expression)
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:

View file

@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
return expression
@ -76,9 +77,17 @@ def simplify_not(expression):
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
return exp.or_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
return exp.and_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
if is_null(condition):
return exp.null()
if always_true(expression.this):
@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(e for _, e in sorted(arr)))
expression = result_func(*(e for _, e in sorted(arr)), copy=False)
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
expression = result_func(*deduped.values())
expression = result_func(*deduped.values(), copy=False)
return expression