1
0
Fork 0

Adding upstream version 18.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:03:05 +01:00
parent c4fc25c23b
commit be16920347
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
96 changed files with 59037 additions and 52828 deletions

View file

@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
expression.args["joins"].remove(join)
@ -163,65 +163,134 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
return expression
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope
taken_select_names = set(expression.named_selects)
taken_source_names = {name for name, _ in Scope(expression).references}
taken_select_names = set(expression.named_selects)
taken_source_names = {name for name, _ in Scope(expression).references}
for select in expression.selects:
to_replace = select
def new_name(names: t.Set[str], name: str) -> str:
name = find_new_name(names, name)
names.add(name)
return name
pos_alias = ""
explode_alias = ""
arrays: t.List[exp.Condition] = []
series_alias = new_name(taken_select_names, "pos")
series = exp.alias_(
exp.Unnest(
expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
),
new_name(taken_source_names, "_u"),
table=[series_alias],
)
if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this
# iterate over a copy because expression.selects is mutated inside the loop
for select in expression.selects.copy():
explode = select.find(exp.Explode, exp.Posexplode)
if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)
if isinstance(explode, (exp.Explode, exp.Posexplode)):
pos_alias = ""
explode_alias = ""
explode_arg = select.this
unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
if isinstance(select, exp.Alias):
explode_alias = select.alias
alias = select
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
alias = select.replace(exp.alias_(select.this, "", copy=False))
else:
alias = select.replace(exp.alias_(select, ""))
explode = alias.find(exp.Explode, exp.Posexplode)
assert explode
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
is_posexplode = isinstance(explode, exp.Posexplode)
explode_arg = explode.this
unnest_source_alias = find_new_name(taken_source_names, "_u")
taken_source_names.add(unnest_source_alias)
# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)
if not explode_alias:
explode_alias = find_new_name(taken_select_names, "col")
taken_select_names.add(explode_alias)
unnest_source_alias = new_name(taken_source_names, "_u")
if not explode_alias:
explode_alias = new_name(taken_select_names, "col")
if is_posexplode:
pos_alias = new_name(taken_select_names, "pos")
if not pos_alias:
pos_alias = new_name(taken_select_names, "pos")
alias.set("alias", exp.to_identifier(explode_alias))
column = exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(explode_alias),
)
explode.replace(column)
if is_posexplode:
pos_alias = find_new_name(taken_select_names, "pos")
taken_select_names.add(pos_alias)
expressions = expression.expressions
expressions.insert(
expressions.index(alias) + 1,
exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(pos_alias),
).as_(pos_alias),
)
expression.set("expressions", expressions)
if is_posexplode:
column_names = [explode_alias, pos_alias]
to_replace.pop()
expression.select(pos_alias, explode_alias, copy=False)
else:
column_names = [explode_alias]
to_replace.replace(exp.column(explode_alias))
if not arrays:
if expression.args.get("from"):
expression.join(series, copy=False)
else:
expression.from_(series, copy=False)
unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
arrays.append(size)
if not expression.args.get("from"):
expression.from_(unnest, copy=False)
else:
expression.join(unnest, join_type="CROSS", copy=False)
# trino doesn't support left join unnest with on conditions
# if it did, this would be much simpler
expression.join(
exp.alias_(
exp.Unnest(
expressions=[explode_arg.copy()],
offset=exp.to_identifier(pos_alias),
),
unnest_source_alias,
table=[explode_alias],
),
join_type="CROSS",
copy=False,
)
return expression
if index_offset != 1:
size = size - 1
expression.where(
exp.column(series_alias)
.eq(exp.column(pos_alias))
.or_(
(exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
),
copy=False,
)
if arrays:
end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])
if index_offset != 1:
end = end - (1 - index_offset)
series.expressions[0].set("end", end)
return expression
return _explode_to_unnest
# Both percentile expression classes (continuous and discrete), grouped in one tuple.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
@ -283,6 +352,31 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
return expression
def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    """Turn an ``exp.Timestamp`` with no ``expression`` arg into a cast to TIMESTAMP.

    Any other node, or a Timestamp whose ``expression`` arg is set, is
    returned unchanged.
    """
    needs_cast = isinstance(expression, exp.Timestamp) and not expression.expression
    if not needs_cast:
        return expression

    return exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Rewrite SEMI and ANTI joins that carry an ON condition as [NOT] EXISTS filters.

    For every such join on a SELECT, the join is detached and the predicate
    ``EXISTS(SELECT 1 FROM <join source> WHERE <on condition>)`` (wrapped in
    NOT for ANTI joins) is passed to ``expression.where``. Non-SELECT
    expressions, other join kinds, and SEMI/ANTI joins without an ON
    condition are left untouched.

    Args:
        expression: The expression to transform; a SELECT is modified in place.

    Returns:
        The same expression object that was passed in.
    """
    if not isinstance(expression, exp.Select):
        return expression

    # Walk a snapshot of the joins: join.pop() detaches the join from this SELECT,
    # so the live "joins" list must not be the sequence being iterated (otherwise
    # a join following a removed one could be skipped).
    for join in list(expression.args.get("joins") or []):
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        subquery = exp.select("1").from_(join.this).where(on)
        exists = exp.Exists(this=subquery)
        if join.kind == "ANTI":
            exists = exists.not_(copy=False)

        join.pop()
        expression.where(exists, copy=False)

    return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
@ -327,12 +421,3 @@ def preprocess(
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
return _to_sql
def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    """Turn an ``exp.Timestamp`` with no ``expression`` arg into a cast to TIMESTAMP.

    Any other node, or a Timestamp whose ``expression`` arg is set, is
    returned unchanged.
    """
    needs_cast = isinstance(expression, exp.Timestamp) and not expression.expression
    if not needs_cast:
        return expression

    return exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)