from __future__ import annotations

import typing as t

from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
    binary_from_function,
    format_time_lambda,
    is_parse_json,
    pivot_column_names,
    rename_func,
    trim_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
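

# Renders an exp.Map, whose keys and values are stored as two parallel arrays, as
# MAP_FROM_ARRAYS(keys, values); an empty map falls back to MAP().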
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
    keys = expression.args.get("keys")
    values = expression.args.get("values")

    if not keys or not values:
        return "MAP()"

    return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
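

# Builds a parser that turns a single-argument call into a cast, e.g. BOOLEAN(x) into
# CAST(x AS BOOLEAN); the FUNCTIONS table below uses it for BOOLEAN, DATE, DOUBLE,
# FLOAT, INT, STRING and TIMESTAMP.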
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
    return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))


def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.DATE_FORMAT:
        return f"TO_DATE({this})"
    return f"TO_DATE({this}, {time_format})"
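

# Chooses the epoch conversion function based on the UnixToTime scale: no scale means a
# plain FROM_UNIXTIME cast, the named scales map to TIMESTAMP_SECONDS / TIMESTAMP_MILLIS /
# TIMESTAMP_MICROS, and any other scale is normalized to seconds via POW(10, scale).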
def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale is None:
        return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
    if scale == exp.UnixToTime.SECONDS:
        return f"TIMESTAMP_SECONDS({timestamp})"
    if scale == exp.UnixToTime.MILLIS:
        return f"TIMESTAMP_MILLIS({timestamp})"
    if scale == exp.UnixToTime.MICROS:
        return f"TIMESTAMP_MICROS({timestamp})"

    return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"


def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
    pivoted source in a subquery with the same alias to preserve the query's semantics.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
    """
    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
        pivot = expression.this.args["pivots"][0]
        if pivot.alias:
            alias = pivot.args["alias"].pop()
            return exp.From(
                this=expression.this.replace(
                    exp.select("*")
                    .from_(expression.this.copy(), copy=False)
                    .subquery(alias=alias, copy=False)
                )
            )

    return expression


def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
    so we need to unqualify it.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
    """
    if isinstance(expression, exp.Pivot):
        expression.set("field", transforms.unqualify_columns(expression.args["field"]))

    return expression


class Spark2(Hive):
    class Parser(Hive.Parser):
        TRIM_PATTERN_FIRST = True

        FUNCTIONS = {
            **Hive.Parser.FUNCTIONS,
            "AGGREGATE": exp.Reduce.from_arg_list,
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "BOOLEAN": _parse_as_cast("boolean"),
            "DATE": _parse_as_cast("date"),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
            ),
            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
            "DOUBLE": _parse_as_cast("double"),
            "FLOAT": _parse_as_cast("float"),
            "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
                this=exp.cast_unless(
                    seq_get(args, 0) or exp.Var(this=""),
                    exp.DataType.build("timestamp"),
                    exp.DataType.build("timestamp"),
                ),
                zone=seq_get(args, 1),
            ),
            "IIF": exp.If.from_arg_list,
            "INT": _parse_as_cast("int"),
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "RLIKE": exp.RegexpLike.from_arg_list,
            "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
            "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
            "STRING": _parse_as_cast("string"),
            "TIMESTAMP": _parse_as_cast("timestamp"),
            "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
            if len(args) == 1
            else format_time_lambda(exp.StrToTime, "spark")(args),
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
        }
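
        # Of the parsers above, TO_TIMESTAMP is the only arity-dependent one: with a single
        # argument it is treated as a plain CAST(x AS TIMESTAMP), while the two-argument form
        # becomes exp.StrToTime with the format string mapped through Spark's time format.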

        FUNCTION_PARSERS = {
            **Hive.Parser.FUNCTION_PARSERS,
            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
            "MERGE": lambda self: self._parse_join_hint("MERGE"),
            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
        }

        def _parse_add_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

        def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop, this=self._parse_schema(), kind="COLUMNS"
            )
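
        # Spark doesn't add an aggregation-based suffix to the pivoted column names when
        # there is a single aggregation, so no names are generated in that case.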
        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
            if len(aggregations) == 1:
                return [""]
            return pivot_column_names(aggregations, dialect="spark")

    class Generator(Hive.Generator):
        QUERY_HINTS = True
        NVL2_SUPPORTED = True

        PROPERTIES_LOCATION = {
            **Hive.Generator.PROPERTIES_LOCATION,
            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
        }

        TRANSFORMS = {
            **Hive.Generator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.DateFromParts: rename_func("MAKE_DATE"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.From: transforms.preprocess([_unalias_pivot]),
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
            exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
            exp.Reduce: rename_func("AGGREGATE"),
            exp.RegexpReplace: lambda self, e: self.func(
                "REGEXP_REPLACE",
                e.this,
                e.expression,
                e.args["replacement"],
                e.args.get("position"),
            ),
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.Trim: trim_sql,
            exp.UnixToTime: _unix_to_time_sql,
            exp.VariancePop: rename_func("VAR_POP"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
            exp.WithinGroup: transforms.preprocess(
                [transforms.remove_within_group_for_percentiles]
            ),
        }
        TRANSFORMS.pop(exp.ArrayJoin)
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)
        TRANSFORMS.pop(exp.Left)
        TRANSFORMS.pop(exp.MonthsBetween)
        TRANSFORMS.pop(exp.Right)

        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False
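
        # Struct arguments expressed as key-value pairs are rewritten as aliased
        # expressions so the generator emits STRUCT(value AS name); plain arguments
        # pass through unchanged.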
        def struct_sql(self, expression: exp.Struct) -> str:
            args = []
            for arg in expression.expressions:
                if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
                    if isinstance(arg, exp.Bracket):
                        args.append(exp.alias_(arg.this, arg.expressions[0].name))
                    else:
                        args.append(exp.alias_(arg.expression, arg.this.name))
                else:
                    args.append(arg)

            return self.func("STRUCT", *args)

        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
            # spark2, spark, Databricks require a storage provider for temporary tables
            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
            expression.args["properties"].append("expressions", provider)
            return expression
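
        # Casting a parsed JSON value to a type is rendered as FROM_JSON(value, 'schema'),
        # casting a value to the JSON type is rendered as TO_JSON(value), and everything
        # else falls through to the default cast generator.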
        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
            if is_parse_json(expression.this):
                schema = f"'{self.sql(expression, 'to')}'"
                return self.func("FROM_JSON", expression.this.this, schema)

            if is_parse_json(expression):
                return self.func("TO_JSON", expression.this)

            return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)

        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
            return super().columndef_sql(
                expression,
                sep=": "
                if isinstance(expression.parent, exp.DataType)
                and expression.parent.is_type("struct")
                else sep,
            )

    class Tokenizer(Hive.Tokenizer):
        HEX_STRINGS = [("X'", "'")]
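

# Example usage (a minimal sketch; the dialect name "spark2" is assumed to be the
# snake_case registration that sqlglot derives from this class's name):
#
#     import sqlglot
#
#     sqlglot.transpile(
#         "SELECT BOOLEAN(x), TO_TIMESTAMP(y, 'yyyy-MM-dd') FROM t",
#         read="spark2",
#         write="hive",
#     )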