1
0
Fork 0

Adding upstream version 11.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:51:35 +01:00
parent b4e0e3422e
commit 82a8846a46
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 44104 additions and 39367 deletions

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
def _explode_to_unnest_sql(self, expression):
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
return self.lateral_sql(expression)
def _initcap_sql(self, expression):
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
_ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
_ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self, expression):
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
if expression.parent:
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
def _quantile_sql(self, expression):
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(self, expression):
def _str_to_time_sql(
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
def _ts_or_ds_add_sql(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
this = self.func(
"DATE_PARSE",
self.func(
"SUBSTR",
this if this.is_string else exp.cast(this, "VARCHAR"),
exp.Literal.number(1),
exp.Literal.number(10),
),
Presto.date_format,
)
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
self.func(
"DATE_PARSE",
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
Presto.date_format,
),
this,
)
def _sequence_sql(self, expression):
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
return self.func("SEQUENCE", start, end, step)
def _ensure_utf8(charset):
def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args):
def _approx_percentile(args: t.Sequence) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@ -157,7 +172,7 @@ def _approx_percentile(args):
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args):
def _from_unixtime(args: t.Sequence) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@ -226,11 +241,15 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
TYPE_MAPPING = {
@ -246,7 +265,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@ -284,6 +302,9 @@ class Presto(Dialect):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
@ -308,7 +329,13 @@ class Presto(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if expression.this and unit.lower().startswith("week"):
return f"({expression.this.name} * INTERVAL '7' day)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"