Merging upstream version 11.7.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
0c053462ae
commit
8d96084fad
144 changed files with 44104 additions and 39367 deletions
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens, transforms
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
|
@ -19,20 +21,20 @@ from sqlglot.helper import seq_get
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _approx_distinct_sql(self, expression):
|
||||
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
|
||||
accuracy = expression.args.get("accuracy")
|
||||
accuracy = ", " + self.sql(accuracy) if accuracy else ""
|
||||
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
|
||||
|
||||
|
||||
def _datatype_sql(self, expression):
|
||||
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
|
||||
sql = self.datatype_sql(expression)
|
||||
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
|
||||
sql = f"{sql} WITH TIME ZONE"
|
||||
return sql
|
||||
|
||||
|
||||
def _explode_to_unnest_sql(self, expression):
|
||||
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
|
||||
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
|
||||
return self.sql(
|
||||
exp.Join(
|
||||
|
@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression):
|
|||
return self.lateral_sql(expression)
|
||||
|
||||
|
||||
def _initcap_sql(self, expression):
|
||||
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
|
||||
regex = r"(\w)(\w*)"
|
||||
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
|
||||
|
||||
|
||||
def _decode_sql(self, expression):
|
||||
_ensure_utf8(expression.args.get("charset"))
|
||||
def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
|
||||
_ensure_utf8(expression.args["charset"])
|
||||
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))
|
||||
|
||||
|
||||
def _encode_sql(self, expression):
|
||||
_ensure_utf8(expression.args.get("charset"))
|
||||
def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
|
||||
_ensure_utf8(expression.args["charset"])
|
||||
return f"TO_UTF8({self.sql(expression, 'this')})"
|
||||
|
||||
|
||||
def _no_sort_array(self, expression):
|
||||
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
|
||||
if expression.args.get("asc") == exp.false():
|
||||
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
|
||||
else:
|
||||
|
@ -70,49 +72,62 @@ def _no_sort_array(self, expression):
|
|||
return self.func("ARRAY_SORT", expression.this, comparator)
|
||||
|
||||
|
||||
def _schema_sql(self, expression):
|
||||
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
|
||||
if isinstance(expression.parent, exp.Property):
|
||||
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
|
||||
return f"ARRAY[{columns}]"
|
||||
|
||||
for schema in expression.parent.find_all(exp.Schema):
|
||||
if isinstance(schema.parent, exp.Property):
|
||||
expression = expression.copy()
|
||||
expression.expressions.extend(schema.expressions)
|
||||
if expression.parent:
|
||||
for schema in expression.parent.find_all(exp.Schema):
|
||||
if isinstance(schema.parent, exp.Property):
|
||||
expression = expression.copy()
|
||||
expression.expressions.extend(schema.expressions)
|
||||
|
||||
return self.schema_sql(expression)
|
||||
|
||||
|
||||
def _quantile_sql(self, expression):
|
||||
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
|
||||
self.unsupported("Presto does not support exact quantiles")
|
||||
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
|
||||
|
||||
|
||||
def _str_to_time_sql(self, expression):
|
||||
def _str_to_time_sql(
|
||||
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
|
||||
) -> str:
|
||||
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
|
||||
|
||||
|
||||
def _ts_or_ds_to_date_sql(self, expression):
|
||||
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
|
||||
time_format = self.format_time(expression)
|
||||
if time_format and time_format not in (Presto.time_format, Presto.date_format):
|
||||
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
|
||||
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
|
||||
|
||||
|
||||
def _ts_or_ds_add_sql(self, expression):
|
||||
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
|
||||
this = expression.this
|
||||
|
||||
if not isinstance(this, exp.CurrentDate):
|
||||
this = self.func(
|
||||
"DATE_PARSE",
|
||||
self.func(
|
||||
"SUBSTR",
|
||||
this if this.is_string else exp.cast(this, "VARCHAR"),
|
||||
exp.Literal.number(1),
|
||||
exp.Literal.number(10),
|
||||
),
|
||||
Presto.date_format,
|
||||
)
|
||||
|
||||
return self.func(
|
||||
"DATE_ADD",
|
||||
exp.Literal.string(expression.text("unit") or "day"),
|
||||
expression.expression,
|
||||
self.func(
|
||||
"DATE_PARSE",
|
||||
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
|
||||
Presto.date_format,
|
||||
),
|
||||
this,
|
||||
)
|
||||
|
||||
|
||||
def _sequence_sql(self, expression):
|
||||
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
|
||||
|
@ -135,12 +150,12 @@ def _sequence_sql(self, expression):
|
|||
return self.func("SEQUENCE", start, end, step)
|
||||
|
||||
|
||||
def _ensure_utf8(charset):
|
||||
def _ensure_utf8(charset: exp.Literal) -> None:
|
||||
if charset.name.lower() != "utf-8":
|
||||
raise UnsupportedError(f"Unsupported charset {charset}")
|
||||
|
||||
|
||||
def _approx_percentile(args):
|
||||
def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
||||
if len(args) == 4:
|
||||
return exp.ApproxQuantile(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -157,7 +172,7 @@ def _approx_percentile(args):
|
|||
return exp.ApproxQuantile.from_arg_list(args)
|
||||
|
||||
|
||||
def _from_unixtime(args):
|
||||
def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
||||
if len(args) == 3:
|
||||
return exp.UnixToTime(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -226,11 +241,15 @@ class Presto(Dialect):
|
|||
FUNCTION_PARSERS.pop("TRIM")
|
||||
|
||||
class Generator(generator.Generator):
|
||||
INTERVAL_ALLOWS_PLURAL_FORM = False
|
||||
JOIN_HINTS = False
|
||||
TABLE_HINTS = False
|
||||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -246,7 +265,6 @@ class Presto(Dialect):
|
|||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
**transforms.ELIMINATE_QUALIFY, # type: ignore
|
||||
exp.ApproxDistinct: _approx_distinct_sql,
|
||||
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
|
@ -284,6 +302,9 @@ class Presto(Dialect):
|
|||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.SafeDivide: no_safe_divide_sql,
|
||||
exp.Schema: _schema_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_qualify, transforms.explode_to_unnest]
|
||||
),
|
||||
exp.SortArray: _no_sort_array,
|
||||
exp.StrPosition: rename_func("STRPOS"),
|
||||
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
|
||||
|
@ -308,7 +329,13 @@ class Presto(Dialect):
|
|||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
}
|
||||
|
||||
def transaction_sql(self, expression):
|
||||
def interval_sql(self, expression: exp.Interval) -> str:
|
||||
unit = self.sql(expression, "unit")
|
||||
if expression.this and unit.lower().startswith("week"):
|
||||
return f"({expression.this.name} * INTERVAL '7' day)"
|
||||
return super().interval_sql(expression)
|
||||
|
||||
def transaction_sql(self, expression: exp.Transaction) -> str:
|
||||
modes = expression.args.get("modes")
|
||||
modes = f" {', '.join(modes)}" if modes else ""
|
||||
return f"START TRANSACTION{modes}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue