2025-02-13 14:53:05 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2025-02-13 21:55:40 +01:00
|
|
|
from sqlglot import exp, parser
|
2025-02-13 21:52:55 +01:00
|
|
|
from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql, timestrtotime_sql
|
2025-02-13 06:15:54 +01:00
|
|
|
from sqlglot.dialects.presto import Presto
|
2025-02-13 21:55:40 +01:00
|
|
|
from sqlglot.tokens import TokenType
|
2025-02-13 22:00:29 +01:00
|
|
|
import typing as t
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
|
|
|
class Trino(Presto):
    """The Trino dialect (https://trino.io), a fork of Presto.

    Only the deltas from Presto are declared here; all other behavior is
    inherited from the ``Presto`` dialect classes.
    """

    # Trino has no user-defined (CREATE TYPE) types.
    SUPPORTS_USER_DEFINED_TYPES = False

    # In Trino, LOG(b, x) takes the logarithm base as the *first* argument.
    LOG_BASE_FIRST = True

    class Tokenizer(Presto.Tokenizer):
        # Trino spells hex/binary literals as X'deadbeef'.
        HEX_STRINGS = [("X'", "'")]

    class Parser(Presto.Parser):
        FUNCTION_PARSERS = {
            **Presto.Parser.FUNCTION_PARSERS,
            "TRIM": lambda self: self._parse_trim(),
            "JSON_QUERY": lambda self: self._parse_json_query(),
            "LISTAGG": lambda self: self._parse_string_agg(),
        }

        # Wrapper options accepted after the path argument of JSON_QUERY.
        # Trino's grammar for this clause is:
        #   (WITH | WITHOUT) [CONDITIONAL | UNCONDITIONAL] [ARRAY] WRAPPER
        JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {
            **dict.fromkeys(
                ("WITH", "WITHOUT"),
                (
                    "WRAPPER",
                    ("ARRAY", "WRAPPER"),
                    ("CONDITIONAL", "WRAPPER"),
                    # BUG FIX: this entry previously read
                    # ("CONDITIONAL", "ARRAY", "WRAPPED"), which is not valid
                    # Trino syntax, so the valid "... CONDITIONAL ARRAY WRAPPER"
                    # form failed to match (every sibling ends in WRAPPER).
                    ("CONDITIONAL", "ARRAY", "WRAPPER"),
                    ("UNCONDITIONAL", "WRAPPER"),
                    ("UNCONDITIONAL", "ARRAY", "WRAPPER"),
                ),
            ),
        }

        def _parse_json_query_quote(self) -> t.Optional[exp.JSONExtractQuote]:
            """Parse JSON_QUERY's optional (KEEP | OMIT) QUOTES [ON SCALAR STRING]
            clause, returning ``None`` (consuming nothing) when it is absent."""
            if not (
                self._match_text_seq("KEEP", "QUOTES") or self._match_text_seq("OMIT", "QUOTES")
            ):
                return None

            return self.expression(
                exp.JSONExtractQuote,
                # After matching "<KEEP|OMIT> QUOTES", the KEEP/OMIT keyword
                # sits two tokens behind the current index.
                option=self._tokens[self._index - 2].text.upper(),
                scalar=self._match_text_seq("ON", "SCALAR", "STRING"),
            )

        def _parse_json_query(self) -> exp.JSONExtract:
            """Parse JSON_QUERY(<json>, <path> [wrapper option] [quote option])."""
            return self.expression(
                exp.JSONExtract,
                this=self._parse_bitwise(),
                expression=self._match(TokenType.COMMA) and self._parse_bitwise(),
                option=self._parse_var_from_options(self.JSON_QUERY_OPTIONS, raise_unmatched=False),
                # Flags this node as JSON_QUERY (rather than plain JSON
                # extraction) so the generator can round-trip the function name.
                json_query=True,
                quote=self._parse_json_query_quote(),
            )

    class Generator(Presto.Generator):
        PROPERTIES_LOCATION = {
            **Presto.Generator.PROPERTIES_LOCATION,
            # Trino renders LOCATION after the WITH properties block.
            exp.LocationProperty: exp.Properties.Location.POST_WITH,
        }

        TRANSFORMS = {
            **Presto.Generator.TRANSFORMS,
            exp.ArraySum: lambda self,
            e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.ArrayUniqueAgg: lambda self, e: f"ARRAY_AGG(DISTINCT {self.sql(e, 'this')})",
            exp.LocationProperty: lambda self, e: self.property_sql(e),
            exp.Merge: merge_without_target_sql,
            exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
            exp.Trim: trim_sql,
            exp.JSONExtract: lambda self, e: self.jsonextract_sql(e),
        }

        # JSON path components Trino can express natively.
        SUPPORTED_JSON_PATH_PARTS = {
            exp.JSONPathKey,
            exp.JSONPathRoot,
            exp.JSONPathSubscript,
        }

        def jsonextract_sql(self, expression: exp.JSONExtract) -> str:
            """Render JSON extraction; nodes flagged ``json_query`` are emitted
            as JSON_QUERY(...) with their wrapper/quote options, all others
            fall through to Presto's rendering."""
            if not expression.args.get("json_query"):
                return super().jsonextract_sql(expression)

            json_path = self.sql(expression, "expression")
            option = self.sql(expression, "option")
            option = f" {option}" if option else ""

            quote = self.sql(expression, "quote")
            quote = f" {quote}" if quote else ""

            return self.func("JSON_QUERY", expression.this, json_path + option + quote)

        def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
            """Render GROUP_CONCAT as Trino's
            LISTAGG(<expr>, <sep> [ON OVERFLOW ...]) WITHIN GROUP (<order>)."""
            this = expression.this
            # Trino requires an explicit separator; default to a comma.
            separator = expression.args.get("separator") or exp.Literal.string(",")

            if isinstance(this, exp.Order):
                # Detach the aggregated expression from the ORDER node so the
                # WITHIN GROUP clause renders only the ORDER BY part.
                if this.this:
                    this = this.this.pop()

                on_overflow = self.sql(expression, "on_overflow")
                on_overflow = f" ON OVERFLOW {on_overflow}" if on_overflow else ""
                return f"LISTAGG({self.format_args(this, separator)}{on_overflow}) WITHIN GROUP ({self.sql(expression.this).lstrip()})"

            return super().groupconcat_sql(expression)