1
0
Fork 0

Adding upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:46 +01:00
parent 6a89523da4
commit 5bd573dda1
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 73384 additions and 73067 deletions

View file

@ -5,9 +5,11 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
binary_from_function,
bool_xor_sql,
date_trunc_to_time,
datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
if_sql,
@ -22,6 +24,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
@ -95,17 +98,16 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
expression = ts_or_ds_add_cast(expression)
unit = exp.Literal.string(expression.text("unit") or "day")
return self.func("DATE_ADD", unit, expression.expression, expression.this)
if not isinstance(this, exp.CurrentDate):
this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
this,
)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
unit = exp.Literal.string(expression.text("unit") or "day")
return self.func("DATE_DIFF", unit, expr, this)
def _approx_percentile(args: t.List) -> exp.Expression:
@ -136,11 +138,11 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
def _parse_element_at(args: t.List) -> exp.SafeBracket:
def _parse_element_at(args: t.List) -> exp.Bracket:
this = seq_get(args, 0)
index = seq_get(args, 1)
assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
@ -168,6 +170,22 @@ def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) ->
return rename_func("ARBITRARY")(self, expression)
def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return rename_func("FROM_UNIXTIME")(self, expression)
if scale == exp.UnixToTime.MILLIS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
if scale == exp.UnixToTime.MICROS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
if scale == exp.UnixToTime.NANOS:
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
self.unsupported(f"Unsupported scale for timestamp: {scale}.")
return ""
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@ -175,11 +193,12 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@ -229,6 +248,7 @@ class Presto(Dialect):
),
"ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"SET_AGG": exp.ArrayUniqueAgg.from_arg_list,
"SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
@ -253,6 +273,7 @@ class Presto(Dialect):
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
SUPPORTS_SINGLE_ARG_CONCAT = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@ -284,6 +305,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
@ -298,7 +320,7 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
@ -330,9 +352,6 @@ class Presto(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.Right: right_to_substring_sql,
exp.SafeBracket: lambda self, e: self.func(
"ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@ -361,10 +380,11 @@ class Presto(Dialect):
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
@ -374,8 +394,24 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
return self.func(
"ELEMENT_AT",
expression.this,
seq_get(
apply_index_offset(
expression.this,
expression.expressions,
1 - expression.args.get("offset", 0),
),
0,
),
)
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions):
if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)