from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
from sqlglot.helper import ensure_list, seq_get
from sqlglot.transforms import (
    ctas_with_tmp_tables_to_create_tmp_view,
    remove_unique_constraints,
    preprocess,
    move_partitioned_by_to_schema_columns,
)


def _build_datediff(args: t.List) -> exp.Expression:
    """
    Although Spark docs don't mention the "unit" argument, Spark3 added support for
    it at some point. Databricks also supports this variant (see below).

    For example, in spark-sql (v3.3.1):

    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4

    See also:

    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
    """
    unit = None
    this = seq_get(args, 0)
    expression = seq_get(args, 1)

    if len(args) == 3:
        unit = this
        this = args[2]

    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
    )
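
# A usage sketch (assumes sqlglot's top-level parse_one helper): both call shapes
# normalize to the same exp.DateDiff node, with the end date stored in `this`:
#
#     from sqlglot import parse_one
#     parse_one("SELECT DATEDIFF('2020-01-05', '2020-01-01')", read="spark")
#     parse_one("SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')", read="spark")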


def _build_dateadd(args: t.List) -> exp.Expression:
    expression = seq_get(args, 1)

    if len(args) == 2:
        # DATE_ADD(startDate, numDays INTEGER)
        # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
        return exp.TsOrDsAdd(
            this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
        )

    # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
    # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
    return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
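
# A usage sketch (assumes sqlglot's top-level parse_one helper): the arity selects
# the node type, so the unit defaults to DAY only in the 2-arg form:
#
#     from sqlglot import parse_one
#     parse_one("SELECT DATE_ADD('2020-01-01', 5)", read="spark")         # exp.TsOrDsAdd
#     parse_one("SELECT DATE_ADD(MONTH, 1, '2020-01-01')", read="spark")  # exp.TimestampAdd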


def _normalize_partition(e: exp.Expression) -> exp.Expression:
    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
    if isinstance(e, str):
        return exp.to_identifier(e)
    if isinstance(e, exp.Literal):
        return exp.to_identifier(e.name)
    return e
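
# A normalization sketch: a quoted partition column such as PARTITIONED BY ('ds')
# arrives here as an exp.Literal, which is rewritten to the identifier `ds` so it
# round-trips like the unquoted PARTITIONED BY (ds) form.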


def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
    if not expression.unit or (
        isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
    ):
        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
        return self.func("DATE_ADD", expression.this, expression.expression)

    this = self.func(
        "DATE_ADD",
        unit_to_var(expression),
        expression.expression,
        expression.this,
    )

    if isinstance(expression, exp.TsOrDsAdd):
        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
        # in other dialects
        return_type = expression.return_type
        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
            this = f"CAST({this} AS {return_type})"

    return this
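
# A generation sketch: exp.TsOrDsAdd with a DAY unit (or no unit) renders as the
# legacy 2-arg DATE_ADD(this, expression); any other unit renders as the 3-arg form,
# wrapped in a CAST back to the node's return type when that type is not a timestamp:
#
#     DATE_ADD(col, 5)                          <- DAY unit or no unit
#     CAST(DATE_ADD(MONTH, 1, col) AS DATE)     <- exp.TsOrDsAdd returning DATE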


class Spark(Spark2):
    class Tokenizer(Spark2.Tokenizer):
        RAW_STRINGS = [
            (prefix + q, q)
            for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
            for prefix in ("r", "R")
        ]
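
        # A tokenizer sketch: with these entries, r'a\b' and R"a\b" lex as raw
        # strings, so the backslash reaches the AST without escape processing.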

    class Parser(Spark2.Parser):
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,
            "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
            "DATE_ADD": _build_dateadd,
            "DATEADD": _build_dateadd,
            "TIMESTAMPADD": _build_dateadd,
            "DATEDIFF": _build_datediff,
            "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
            "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
            "TRY_ELEMENT_AT": lambda args: exp.Bracket(
                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
            ),
        }

        def _parse_generated_as_identity(
            self,
        ) -> (
            exp.GeneratedAsIdentityColumnConstraint
            | exp.ComputedColumnConstraint
            | exp.GeneratedAsRowColumnConstraint
        ):
            this = super()._parse_generated_as_identity()
            if this.expression:
                return self.expression(exp.ComputedColumnConstraint, this=this.expression)
            return this
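
        # A parsing sketch: `x INT GENERATED ALWAYS AS (y + 1)` carries an inner
        # expression, so it becomes exp.ComputedColumnConstraint; identity columns
        # (GENERATED ALWAYS AS IDENTITY) fall through to the parent constraint.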

    class Generator(Spark2.Generator):
        SUPPORTS_TO_NUMBER = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
            exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
            exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
        }

        TRANSFORMS = {
            **Spark2.Generator.TRANSFORMS,
            exp.ArrayConstructCompact: lambda self, e: self.func(
                "ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
            ),
            exp.Create: preprocess(
                [
                    remove_unique_constraints,
                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
                        e, temporary_storage_provider
                    ),
                    move_partitioned_by_to_schema_columns,
                ]
            ),
            exp.PartitionedByProperty: lambda self,
            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.TsOrDsAdd: _dateadd_sql,
            exp.TimestampAdd: _dateadd_sql,
            exp.TryCast: lambda self, e: (
                self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
            ),
        }
        TRANSFORMS.pop(exp.AnyValue)
        TRANSFORMS.pop(exp.DateDiff)
        TRANSFORMS.pop(exp.Group)

        def bracket_sql(self, expression: exp.Bracket) -> str:
            if expression.args.get("safe"):
                key = seq_get(self.bracket_offset_expressions(expression), 0)
                return self.func("TRY_ELEMENT_AT", expression.this, key)

            return super().bracket_sql(expression)
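
        # A round-trip sketch: TRY_ELEMENT_AT(arr, 1) parses into a safe
        # exp.Bracket (see the Parser above) and regenerates here as
        # TRY_ELEMENT_AT, while a plain arr[1] defers to the parent rendering.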

        def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
            return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"

        def anyvalue_sql(self, expression: exp.AnyValue) -> str:
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            end = self.sql(expression, "this")
            start = self.sql(expression, "expression")

            if expression.unit:
                return self.func("DATEDIFF", unit_to_var(expression), start, end)

            return self.func("DATEDIFF", end, start)
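
        # A round-trip sketch (assumes sqlglot's top-level transpile helper): the
        # unit-ful form regenerates with the unit leading, e.g.
        #
        #     import sqlglot
        #     sqlglot.transpile("SELECT DATEDIFF(day, x, y)", read="spark", write="spark")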