Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
8deb804d23
commit
fc63828ee4
167 changed files with 58268 additions and 51337 deletions
|
@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
|
|||
format_time_lambda,
|
||||
if_sql,
|
||||
no_ilike_sql,
|
||||
no_pivot_sql,
|
||||
no_safe_divide_sql,
|
||||
rename_func,
|
||||
struct_extract_sql,
|
||||
|
@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
|
|||
)
|
||||
|
||||
|
||||
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step")
|
||||
|
||||
target_type = None
|
||||
|
||||
if isinstance(start, exp.Cast):
|
||||
target_type = start.to
|
||||
elif isinstance(end, exp.Cast):
|
||||
target_type = end.to
|
||||
|
||||
if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
|
||||
to = target_type.copy()
|
||||
|
||||
if target_type is start.to:
|
||||
end = exp.Cast(this=end, to=to)
|
||||
else:
|
||||
start = exp.Cast(this=start, to=to)
|
||||
|
||||
sql = self.func("SEQUENCE", start, end, step)
|
||||
if isinstance(expression.parent, exp.Table):
|
||||
sql = f"UNNEST({sql})"
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def _ensure_utf8(charset: exp.Literal) -> None:
|
||||
if charset.name.lower() != "utf-8":
|
||||
raise UnsupportedError(f"Unsupported charset {charset}")
|
||||
|
||||
|
||||
def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
||||
def _approx_percentile(args: t.List) -> exp.Expression:
|
||||
if len(args) == 4:
|
||||
return exp.ApproxQuantile(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression:
|
|||
return exp.ApproxQuantile.from_arg_list(args)
|
||||
|
||||
|
||||
def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
||||
def _from_unixtime(args: t.List) -> exp.Expression:
|
||||
if len(args) == 3:
|
||||
return exp.UnixToTime(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
|
|||
return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
|
||||
if isinstance(expression, exp.Table):
|
||||
if isinstance(expression.this, exp.GenerateSeries):
|
||||
unnest = exp.Unnest(expressions=[expression.this])
|
||||
|
||||
if expression.alias:
|
||||
return exp.alias_(
|
||||
unnest,
|
||||
alias="_u",
|
||||
table=[expression.alias],
|
||||
copy=False,
|
||||
)
|
||||
return unnest
|
||||
return expression
|
||||
|
||||
|
||||
class Presto(Dialect):
|
||||
index_offset = 1
|
||||
null_ordering = "nulls_are_last"
|
||||
time_format = MySQL.time_format # type: ignore
|
||||
time_mapping = MySQL.time_mapping # type: ignore
|
||||
time_format = MySQL.time_format
|
||||
time_mapping = MySQL.time_mapping
|
||||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"START": TokenType.BEGIN,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"ROW": TokenType.STRUCT,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS, # type: ignore
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
|
||||
"APPROX_PERCENTILE": _approx_percentile,
|
||||
"CARDINALITY": exp.ArraySize.from_arg_list,
|
||||
|
@ -252,13 +243,13 @@ class Presto(Dialect):
|
|||
STRUCT_DELIMITER = ("(", ")")
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING, # type: ignore
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.INT: "INTEGER",
|
||||
exp.DataType.Type.FLOAT: "REAL",
|
||||
exp.DataType.Type.BINARY: "VARBINARY",
|
||||
|
@ -268,8 +259,9 @@ class Presto(Dialect):
|
|||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.ApproxDistinct: _approx_distinct_sql,
|
||||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
exp.ArrayContains: rename_func("CONTAINS"),
|
||||
|
@ -293,7 +285,7 @@ class Presto(Dialect):
|
|||
exp.Decode: _decode_sql,
|
||||
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
|
||||
exp.Encode: _encode_sql,
|
||||
exp.GenerateSeries: _sequence_sql,
|
||||
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
|
||||
exp.Group: transforms.preprocess([transforms.unalias_group]),
|
||||
exp.Hex: rename_func("TO_HEX"),
|
||||
exp.If: if_sql,
|
||||
|
@ -301,10 +293,10 @@ class Presto(Dialect):
|
|||
exp.Initcap: _initcap_sql,
|
||||
exp.Lateral: _explode_to_unnest_sql,
|
||||
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.LogicalAnd: rename_func("BOOL_AND"),
|
||||
exp.LogicalOr: rename_func("BOOL_OR"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Quantile: _quantile_sql,
|
||||
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
|
||||
exp.SafeDivide: no_safe_divide_sql,
|
||||
exp.Schema: _schema_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
|
@ -320,8 +312,7 @@ class Presto(Dialect):
|
|||
exp.StrToTime: _str_to_time_sql,
|
||||
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
|
||||
exp.StructExtract: struct_extract_sql,
|
||||
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
|
||||
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
|
||||
exp.Table: transforms.preprocess([_unnest_sequence]),
|
||||
exp.TimestampTrunc: timestamptrunc_sql,
|
||||
exp.TimeStrToDate: timestrtotime_sql,
|
||||
exp.TimeStrToTime: timestrtotime_sql,
|
||||
|
@ -336,6 +327,7 @@ class Presto(Dialect):
|
|||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
|
||||
exp.VariancePop: rename_func("VAR_POP"),
|
||||
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
|
||||
exp.WithinGroup: transforms.preprocess(
|
||||
[transforms.remove_within_group_for_percentiles]
|
||||
),
|
||||
|
@ -351,3 +343,25 @@ class Presto(Dialect):
|
|||
modes = expression.args.get("modes")
|
||||
modes = f" {', '.join(modes)}" if modes else ""
|
||||
return f"START TRANSACTION{modes}"
|
||||
|
||||
def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
|
||||
start = expression.args["start"]
|
||||
end = expression.args["end"]
|
||||
step = expression.args.get("step")
|
||||
|
||||
if isinstance(start, exp.Cast):
|
||||
target_type = start.to
|
||||
elif isinstance(end, exp.Cast):
|
||||
target_type = end.to
|
||||
else:
|
||||
target_type = None
|
||||
|
||||
if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
|
||||
to = target_type.copy()
|
||||
|
||||
if target_type is start.to:
|
||||
end = exp.Cast(this=end, to=to)
|
||||
else:
|
||||
start = exp.Cast(this=start, to=to)
|
||||
|
||||
return self.func("SEQUENCE", start, end, step)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue