1
0
Fork 0

Merging upstream version 25.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:51:42 +01:00
parent dba379232c
commit aa0eae236a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
102 changed files with 52995 additions and 52070 deletions

View file

@ -28,6 +28,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_str,
sequence_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
@ -204,11 +205,11 @@ def _jsonextract_sql(self: Presto.Generator, expression: exp.JSONExtract) -> str
return f"{this}{expr}"
def _to_int(expression: exp.Expression) -> exp.Expression:
def _to_int(self: Presto.Generator, expression: exp.Expression) -> exp.Expression:
if not expression.type:
from sqlglot.optimizer.annotate_types import annotate_types
annotate_types(expression)
annotate_types(expression, dialect=self.dialect)
if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
return exp.cast(expression, to=exp.DataType.Type.BIGINT)
return expression
@ -229,7 +230,7 @@ def _date_delta_sql(
name: str, negate_interval: bool = False
) -> t.Callable[[Presto.Generator, DATE_ADD_OR_SUB], str]:
def _delta_sql(self: Presto.Generator, expression: DATE_ADD_OR_SUB) -> str:
interval = _to_int(expression.expression)
interval = _to_int(self, expression.expression)
return self.func(
name,
unit_to_str(expression),
@ -256,6 +257,21 @@ class Presto(Dialect):
# https://github.com/prestodb/presto/issues/2863
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
# Type-annotation overrides for Presto/Trino, layered on top of the base Dialect annotators.
# Certain math functions in Presto/Trino return the same type as their input,
# e.g. FLOOR(5.5/2) -> DECIMAL, FLOOR(5/2) -> BIGINT
ANNOTATORS = {
# Start from the base dialect's annotators; the entries below win on key collisions.
**Dialect.ANNOTATORS,
exp.Floor: lambda self, e: self._annotate_by_args(e, "this"),
exp.Ceil: lambda self, e: self._annotate_by_args(e, "this"),
# Mod is annotated from both of its operands ("this" and "expression").
exp.Mod: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Round: lambda self, e: self._annotate_by_args(e, "this"),
exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
# Rand is annotated from its argument when one is given; with no argument it is typed DOUBLE.
exp.Rand: lambda self, e: self._annotate_by_args(e, "this")
if e.this
else self._set_type(e, exp.DataType.Type.DOUBLE),
}
class Tokenizer(tokens.Tokenizer):
UNICODE_STRINGS = [
(prefix + q, q)
@ -420,6 +436,7 @@ class Presto(Dialect):
exp.FirstValue: _first_last_sql,
exp.FromTimeZone: lambda self,
e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GenerateSeries: sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@ -572,11 +589,20 @@ class Presto(Dialect):
# timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
# which seems to be using the same time mapping as Hive, as per:
# https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
this = expression.this
value_as_text = exp.cast(this, exp.DataType.Type.TEXT)
value_as_timestamp = (
exp.cast(this, exp.DataType.Type.TIMESTAMP) if this.is_string else this
)
parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
formatted_value = self.func(
"DATE_FORMAT", value_as_timestamp, self.format_time(expression)
)
parse_with_tz = self.func(
"PARSE_DATETIME",
value_as_text,
formatted_value,
self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
)
coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
@ -636,26 +662,6 @@ class Presto(Dialect):
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
    """Render a GenerateSeries expression as a ``SEQUENCE(start, end, step)`` call.

    If either bound is an explicit cast to a timestamp type, the other bound
    is cast to that same type before rendering (presumably so that both
    bounds share one type -- confirm against the engine's SEQUENCE rules).

    Args:
        expression: The GenerateSeries node. Its ``start`` and ``end`` args
            are read by direct indexing; ``step`` is optional.

    Returns:
        The SQL produced by ``self.func("SEQUENCE", start, end, step)``.
    """
    start = expression.args["start"]
    end = expression.args["end"]
    step = expression.args.get("step")

    start_is_cast = isinstance(start, exp.Cast)

    # Prefer the type of a cast on `start`; fall back to a cast on `end`.
    if start_is_cast:
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if target_type and target_type.is_type("timestamp"):
        # Decide which bound to cast from whether `start` is a Cast, rather
        # than reading `start.to` unconditionally: when the type came from
        # `end`, `start` may be any expression and need not expose `.to`.
        if start_is_cast:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]: