Adding upstream version 11.4.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d160fb48f7
commit
36706608dc
89 changed files with 35352 additions and 33081 deletions
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens, transforms
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
|
@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = {
|
|||
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
|
||||
|
||||
|
||||
def _add_date_sql(self, expression):
|
||||
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
|
||||
unit = expression.text("unit").upper()
|
||||
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
|
||||
modified_increment = (
|
||||
|
@ -47,7 +49,7 @@ def _add_date_sql(self, expression):
|
|||
return self.func(func, expression.this, modified_increment.this)
|
||||
|
||||
|
||||
def _date_diff_sql(self, expression):
|
||||
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
|
||||
unit = expression.text("unit").upper()
|
||||
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
|
||||
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
|
||||
|
@ -56,21 +58,21 @@ def _date_diff_sql(self, expression):
|
|||
return f"{diff_sql}{multiplier_sql}"
|
||||
|
||||
|
||||
def _array_sort(self, expression):
|
||||
def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
|
||||
if expression.expression:
|
||||
self.unsupported("Hive SORT_ARRAY does not support a comparator")
|
||||
return f"SORT_ARRAY({self.sql(expression, 'this')})"
|
||||
|
||||
|
||||
def _property_sql(self, expression):
|
||||
def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
|
||||
return f"'{expression.name}'={self.sql(expression, 'value')}"
|
||||
|
||||
|
||||
def _str_to_unix(self, expression):
|
||||
def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
|
||||
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
|
||||
|
||||
|
||||
def _str_to_date(self, expression):
|
||||
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format not in (Hive.time_format, Hive.date_format):
|
||||
|
@ -78,7 +80,7 @@ def _str_to_date(self, expression):
|
|||
return f"CAST({this} AS DATE)"
|
||||
|
||||
|
||||
def _str_to_time(self, expression):
|
||||
def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format not in (Hive.time_format, Hive.date_format):
|
||||
|
@ -86,20 +88,22 @@ def _str_to_time(self, expression):
|
|||
return f"CAST({this} AS TIMESTAMP)"
|
||||
|
||||
|
||||
def _time_format(self, expression):
|
||||
def _time_format(
|
||||
self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
|
||||
) -> t.Optional[str]:
|
||||
time_format = self.format_time(expression)
|
||||
if time_format == Hive.time_format:
|
||||
return None
|
||||
return time_format
|
||||
|
||||
|
||||
def _time_to_str(self, expression):
|
||||
def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
return f"DATE_FORMAT({this}, {time_format})"
|
||||
|
||||
|
||||
def _to_date_sql(self, expression):
|
||||
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format and time_format not in (Hive.time_format, Hive.date_format):
|
||||
|
@ -107,7 +111,7 @@ def _to_date_sql(self, expression):
|
|||
return f"TO_DATE({this})"
|
||||
|
||||
|
||||
def _unnest_to_explode_sql(self, expression):
|
||||
def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
|
||||
unnest = expression.this
|
||||
if isinstance(unnest, exp.Unnest):
|
||||
alias = unnest.args.get("alias")
|
||||
|
@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression):
|
|||
exp.Lateral(
|
||||
this=udtf(this=expression),
|
||||
view=True,
|
||||
alias=exp.TableAlias(this=alias.this, columns=[column]),
|
||||
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
|
||||
)
|
||||
)
|
||||
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
|
||||
|
@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression):
|
|||
return self.join_sql(expression)
|
||||
|
||||
|
||||
def _index_sql(self, expression):
|
||||
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
table = self.sql(expression, "table")
|
||||
columns = self.sql(expression, "columns")
|
||||
|
@ -263,14 +267,15 @@ class Hive(Dialect):
|
|||
exp.DataType.Type.TEXT: "STRING",
|
||||
exp.DataType.Type.DATETIME: "TIMESTAMP",
|
||||
exp.DataType.Type.VARBINARY: "BINARY",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS, # type: ignore
|
||||
**transforms.UNALIAS_GROUP, # type: ignore
|
||||
**transforms.ELIMINATE_QUALIFY, # type: ignore
|
||||
exp.Property: _property_sql,
|
||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||
exp.ArrayAgg: rename_func("COLLECT_LIST"),
|
||||
exp.ArrayConcat: rename_func("CONCAT"),
|
||||
exp.ArraySize: rename_func("SIZE"),
|
||||
exp.ArraySort: _array_sort,
|
||||
|
@ -333,13 +338,19 @@ class Hive(Dialect):
|
|||
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
}
|
||||
|
||||
def with_properties(self, properties):
|
||||
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
|
||||
return self.func(
|
||||
"COLLECT_LIST",
|
||||
expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
|
||||
)
|
||||
|
||||
def with_properties(self, properties: exp.Properties) -> str:
|
||||
return self.properties(
|
||||
properties,
|
||||
prefix=self.seg("TBLPROPERTIES"),
|
||||
)
|
||||
|
||||
def datatype_sql(self, expression):
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
if (
|
||||
expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
|
||||
and not expression.expressions
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue