1
0
Fork 0
sqlglot/sqlglot/dialects/trino.py
Daniel Baumann 1763c7a4ef
Merging upstream version 25.24.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 21:55:40 +01:00

69 lines
2.4 KiB
Python

from __future__ import annotations
from sqlglot import exp, parser
from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql, timestrtotime_sql
from sqlglot.dialects.presto import Presto
from sqlglot.tokens import TokenType
class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
LOG_BASE_FIRST = True
class Parser(Presto.Parser):
FUNCTION_PARSERS = {
**Presto.Parser.FUNCTION_PARSERS,
"TRIM": lambda self: self._parse_trim(),
"JSON_QUERY": lambda self: self._parse_json_query(),
}
JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {
**dict.fromkeys(
("WITH", "WITHOUT"),
(
("CONDITIONAL", "WRAPPER"),
("CONDITIONAL", "ARRAY", "WRAPPED"),
("UNCONDITIONAL", "WRAPPER"),
("UNCONDITIONAL", "ARRAY", "WRAPPER"),
),
),
}
def _parse_json_query(self):
return self.expression(
exp.JSONExtract,
this=self._parse_bitwise(),
expression=self._match(TokenType.COMMA) and self._parse_bitwise(),
option=self._parse_var_from_options(self.JSON_QUERY_OPTIONS, raise_unmatched=False),
json_query=True,
)
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self,
e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
exp.Trim: trim_sql,
exp.JSONExtract: lambda self, e: self.jsonextract_sql(e),
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
def jsonextract_sql(self, expression: exp.JSONExtract) -> str:
if not expression.args.get("json_query"):
return super().jsonextract_sql(expression)
json_path = self.sql(expression, "expression")
option = self.sql(expression, "option")
option = f" {option}" if option else ""
return self.func("JSON_QUERY", expression.this, json_path + option)
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]