2025-02-13 14:52:26 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import typing as t
|
2025-02-13 06:15:54 +01:00
|
|
|
from enum import Enum
|
|
|
|
|
|
|
|
from sqlglot import exp
|
|
|
|
from sqlglot.generator import Generator
|
2025-02-13 14:52:26 +01:00
|
|
|
from sqlglot.helper import flatten, seq_get
|
2025-02-13 06:15:54 +01:00
|
|
|
from sqlglot.parser import Parser
|
|
|
|
from sqlglot.time import format_time
|
2025-02-13 15:56:32 +01:00
|
|
|
from sqlglot.tokens import Token, Tokenizer, TokenType
|
2025-02-13 06:15:54 +01:00
|
|
|
from sqlglot.trie import new_trie
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
if t.TYPE_CHECKING:
|
|
|
|
from sqlglot._typing import E
|
|
|
|
|
|
|
|
|
|
|
|
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
|
|
|
|
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
|
|
|
|
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
|
2025-02-13 15:22:50 +01:00
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
class Dialects(str, Enum):
|
|
|
|
DIALECT = ""
|
|
|
|
|
|
|
|
BIGQUERY = "bigquery"
|
|
|
|
CLICKHOUSE = "clickhouse"
|
|
|
|
DUCKDB = "duckdb"
|
|
|
|
HIVE = "hive"
|
|
|
|
MYSQL = "mysql"
|
|
|
|
ORACLE = "oracle"
|
|
|
|
POSTGRES = "postgres"
|
|
|
|
PRESTO = "presto"
|
2025-02-13 07:47:22 +01:00
|
|
|
REDSHIFT = "redshift"
|
2025-02-13 06:15:54 +01:00
|
|
|
SNOWFLAKE = "snowflake"
|
|
|
|
SPARK = "spark"
|
2025-02-13 15:52:54 +01:00
|
|
|
SPARK2 = "spark2"
|
2025-02-13 06:15:54 +01:00
|
|
|
SQLITE = "sqlite"
|
|
|
|
STARROCKS = "starrocks"
|
|
|
|
TABLEAU = "tableau"
|
|
|
|
TRINO = "trino"
|
2025-02-13 14:30:50 +01:00
|
|
|
TSQL = "tsql"
|
2025-02-13 14:51:09 +01:00
|
|
|
DATABRICKS = "databricks"
|
2025-02-13 14:53:43 +01:00
|
|
|
DRILL = "drill"
|
2025-02-13 15:04:17 +01:00
|
|
|
TERADATA = "teradata"
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
|
|
|
class _Dialect(type):
|
2025-02-13 15:22:50 +01:00
|
|
|
classes: t.Dict[str, t.Type[Dialect]] = {}
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def __eq__(cls, other: t.Any) -> bool:
|
|
|
|
if cls is other:
|
|
|
|
return True
|
|
|
|
if isinstance(other, str):
|
|
|
|
return cls is cls.get(other)
|
|
|
|
if isinstance(other, Dialect):
|
|
|
|
return cls is type(other)
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
def __hash__(cls) -> int:
|
|
|
|
return hash(cls.__name__.lower())
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
@classmethod
|
2025-02-13 15:22:50 +01:00
|
|
|
def __getitem__(cls, key: str) -> t.Type[Dialect]:
|
2025-02-13 06:15:54 +01:00
|
|
|
return cls.classes[key]
|
|
|
|
|
|
|
|
@classmethod
|
2025-02-13 15:22:50 +01:00
|
|
|
def get(
|
|
|
|
cls, key: str, default: t.Optional[t.Type[Dialect]] = None
|
|
|
|
) -> t.Optional[t.Type[Dialect]]:
|
2025-02-13 06:15:54 +01:00
|
|
|
return cls.classes.get(key, default)
|
|
|
|
|
|
|
|
def __new__(cls, clsname, bases, attrs):
|
|
|
|
klass = super().__new__(cls, clsname, bases, attrs)
|
|
|
|
enum = Dialects.__members__.get(clsname.upper())
|
|
|
|
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
|
|
|
|
|
|
|
|
klass.time_trie = new_trie(klass.time_mapping)
|
|
|
|
klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
|
|
|
|
klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
|
|
|
|
|
|
|
|
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
|
|
|
|
klass.parser_class = getattr(klass, "Parser", Parser)
|
|
|
|
klass.generator_class = getattr(klass, "Generator", Generator)
|
|
|
|
|
2025-02-13 07:47:22 +01:00
|
|
|
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
|
2025-02-13 14:52:26 +01:00
|
|
|
klass.identifier_start, klass.identifier_end = list(
|
|
|
|
klass.tokenizer_class._IDENTIFIERS.items()
|
|
|
|
)[0]
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
|
|
|
|
return next(
|
|
|
|
(
|
|
|
|
(s, e)
|
|
|
|
for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
|
|
|
|
if t == token_type
|
|
|
|
),
|
|
|
|
(None, None),
|
|
|
|
)
|
2025-02-13 15:52:54 +01:00
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
|
|
|
|
klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
|
|
|
|
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
|
|
|
|
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 15:58:03 +01:00
|
|
|
klass.tokenizer_class.identifiers_can_start_with_digit = (
|
|
|
|
klass.identifiers_can_start_with_digit
|
|
|
|
)
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
return klass
|
|
|
|
|
|
|
|
|
|
|
|
class Dialect(metaclass=_Dialect):
|
|
|
|
index_offset = 0
|
|
|
|
unnest_column_only = False
|
|
|
|
alias_post_tablesample = False
|
2025-02-13 15:58:03 +01:00
|
|
|
identifiers_can_start_with_digit = False
|
2025-02-13 14:52:26 +01:00
|
|
|
normalize_functions: t.Optional[str] = "upper"
|
2025-02-13 06:15:54 +01:00
|
|
|
null_ordering = "nulls_are_small"
|
|
|
|
|
|
|
|
date_format = "'%Y-%m-%d'"
|
|
|
|
dateint_format = "'%Y%m%d'"
|
|
|
|
time_format = "'%Y-%m-%d %H:%M:%S'"
|
2025-02-13 14:52:26 +01:00
|
|
|
time_mapping: t.Dict[str, str] = {}
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
# autofilled
|
|
|
|
quote_start = None
|
|
|
|
quote_end = None
|
|
|
|
identifier_start = None
|
|
|
|
identifier_end = None
|
|
|
|
|
|
|
|
time_trie = None
|
|
|
|
inverse_time_mapping = None
|
|
|
|
inverse_time_trie = None
|
|
|
|
tokenizer_class = None
|
|
|
|
parser_class = None
|
|
|
|
generator_class = None
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def __eq__(self, other: t.Any) -> bool:
|
|
|
|
return type(self) == other
|
|
|
|
|
|
|
|
def __hash__(self) -> int:
|
|
|
|
return hash(type(self))
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
@classmethod
|
2025-02-13 15:22:50 +01:00
|
|
|
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
|
2025-02-13 06:15:54 +01:00
|
|
|
if not dialect:
|
|
|
|
return cls
|
2025-02-13 15:09:11 +01:00
|
|
|
if isinstance(dialect, _Dialect):
|
|
|
|
return dialect
|
|
|
|
if isinstance(dialect, Dialect):
|
|
|
|
return dialect.__class__
|
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
result = cls.get(dialect)
|
|
|
|
if not result:
|
|
|
|
raise ValueError(f"Unknown dialect '{dialect}'")
|
2025-02-13 15:09:11 +01:00
|
|
|
|
2025-02-13 06:15:54 +01:00
|
|
|
return result
|
|
|
|
|
|
|
|
@classmethod
|
2025-02-13 15:22:50 +01:00
|
|
|
def format_time(
|
|
|
|
cls, expression: t.Optional[str | exp.Expression]
|
|
|
|
) -> t.Optional[exp.Expression]:
|
2025-02-13 06:15:54 +01:00
|
|
|
if isinstance(expression, str):
|
|
|
|
return exp.Literal.string(
|
|
|
|
format_time(
|
|
|
|
expression[1:-1], # the time formats are quoted
|
|
|
|
cls.time_mapping,
|
|
|
|
cls.time_trie,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if expression and expression.is_string:
|
|
|
|
return exp.Literal.string(
|
|
|
|
format_time(
|
|
|
|
expression.this,
|
|
|
|
cls.time_mapping,
|
|
|
|
cls.time_trie,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return expression
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
|
2025-02-13 15:45:33 +01:00
|
|
|
return self.parser(**opts).parse(self.tokenize(sql), sql)
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def parse_into(
|
|
|
|
self, expression_type: exp.IntoType, sql: str, **opts
|
|
|
|
) -> t.List[t.Optional[exp.Expression]]:
|
2025-02-13 15:45:33 +01:00
|
|
|
return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
return self.generator(**opts).generate(expression)
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def transpile(self, sql: str, **opts) -> t.List[str]:
|
|
|
|
return [self.generate(expression, **opts) for expression in self.parse(sql)]
|
2025-02-13 06:15:54 +01:00
|
|
|
|
2025-02-13 15:45:33 +01:00
|
|
|
def tokenize(self, sql: str) -> t.List[Token]:
|
|
|
|
return self.tokenizer.tokenize(sql)
|
|
|
|
|
2025-02-13 14:30:50 +01:00
|
|
|
@property
|
2025-02-13 15:22:50 +01:00
|
|
|
def tokenizer(self) -> Tokenizer:
|
2025-02-13 14:30:50 +01:00
|
|
|
if not hasattr(self, "_tokenizer"):
|
2025-02-13 15:22:50 +01:00
|
|
|
self._tokenizer = self.tokenizer_class() # type: ignore
|
2025-02-13 14:30:50 +01:00
|
|
|
return self._tokenizer
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def parser(self, **opts) -> Parser:
|
|
|
|
return self.parser_class( # type: ignore
|
2025-02-13 06:15:54 +01:00
|
|
|
**{
|
|
|
|
"index_offset": self.index_offset,
|
|
|
|
"unnest_column_only": self.unnest_column_only,
|
|
|
|
"alias_post_tablesample": self.alias_post_tablesample,
|
|
|
|
"null_ordering": self.null_ordering,
|
|
|
|
**opts,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def generator(self, **opts) -> Generator:
|
|
|
|
return self.generator_class( # type: ignore
|
2025-02-13 06:15:54 +01:00
|
|
|
**{
|
|
|
|
"quote_start": self.quote_start,
|
|
|
|
"quote_end": self.quote_end,
|
2025-02-13 15:52:54 +01:00
|
|
|
"bit_start": self.bit_start,
|
|
|
|
"bit_end": self.bit_end,
|
|
|
|
"hex_start": self.hex_start,
|
|
|
|
"hex_end": self.hex_end,
|
|
|
|
"byte_start": self.byte_start,
|
|
|
|
"byte_end": self.byte_end,
|
2025-02-13 15:56:32 +01:00
|
|
|
"raw_start": self.raw_start,
|
|
|
|
"raw_end": self.raw_end,
|
2025-02-13 06:15:54 +01:00
|
|
|
"identifier_start": self.identifier_start,
|
|
|
|
"identifier_end": self.identifier_end,
|
2025-02-13 15:22:50 +01:00
|
|
|
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
|
|
|
|
"identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
|
2025-02-13 06:15:54 +01:00
|
|
|
"index_offset": self.index_offset,
|
|
|
|
"time_mapping": self.inverse_time_mapping,
|
|
|
|
"time_trie": self.inverse_time_trie,
|
|
|
|
"unnest_column_only": self.unnest_column_only,
|
|
|
|
"alias_post_tablesample": self.alias_post_tablesample,
|
2025-02-13 15:58:03 +01:00
|
|
|
"identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
|
2025-02-13 06:15:54 +01:00
|
|
|
"normalize_functions": self.normalize_functions,
|
|
|
|
"null_ordering": self.null_ordering,
|
|
|
|
**opts,
|
|
|
|
}
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
|
2025-02-13 15:09:11 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
|
2025-02-13 15:24:45 +01:00
|
|
|
return lambda self, expression: self.func(name, *flatten(expression.args.values()))
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
if expression.args.get("accuracy"):
|
|
|
|
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
|
2025-02-13 15:24:45 +01:00
|
|
|
return self.func("APPROX_COUNT_DISTINCT", expression.this)
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def if_sql(self: Generator, expression: exp.If) -> str:
|
2025-02-13 15:24:45 +01:00
|
|
|
return self.func(
|
|
|
|
"IF", expression.this, expression.args.get("true"), expression.args.get("false")
|
2025-02-13 14:52:26 +01:00
|
|
|
)
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
|
2025-02-13 15:01:11 +01:00
|
|
|
return self.binary(expression, "->")
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def arrow_json_extract_scalar_sql(
|
|
|
|
self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
|
|
|
|
) -> str:
|
2025-02-13 15:01:11 +01:00
|
|
|
return self.binary(expression, "->>")
|
2025-02-13 06:15:54 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
return f"[{self.expressions(expression)}]"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
return self.like_sql(
|
|
|
|
exp.Like(
|
|
|
|
this=exp.Lower(this=expression.this),
|
|
|
|
expression=expression.args["expression"],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
zone = self.sql(expression, "this")
|
|
|
|
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
if expression.args.get("recursive"):
|
|
|
|
self.unsupported("Recursive CTEs are unsupported")
|
|
|
|
expression.args["recursive"] = False
|
|
|
|
return self.with_sql(expression)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
n = self.sql(expression, "this")
|
|
|
|
d = self.sql(expression, "expression")
|
|
|
|
return f"IF({d} <> 0, {n} / {d}, NULL)"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
self.unsupported("TABLESAMPLE unsupported")
|
|
|
|
return self.sql(expression.this)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
|
2025-02-13 14:37:25 +01:00
|
|
|
self.unsupported("PIVOT unsupported")
|
2025-02-13 15:56:32 +01:00
|
|
|
return ""
|
2025-02-13 14:37:25 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
return self.cast_sql(expression)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
|
2025-02-13 14:46:14 +01:00
|
|
|
self.unsupported("Properties unsupported")
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
2025-02-13 15:51:35 +01:00
|
|
|
def no_comment_column_constraint_sql(
|
|
|
|
self: Generator, expression: exp.CommentColumnConstraint
|
|
|
|
) -> str:
|
|
|
|
self.unsupported("CommentColumnConstraint unsupported")
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
this = self.sql(expression, "this")
|
|
|
|
substr = self.sql(expression, "substr")
|
|
|
|
position = self.sql(expression, "position")
|
|
|
|
if position:
|
|
|
|
return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
|
|
|
|
return f"STRPOS({this}, {substr})"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
|
2025-02-13 06:15:54 +01:00
|
|
|
this = self.sql(expression, "this")
|
|
|
|
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
|
|
|
|
return f"{this}.{struct_key}"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def var_map_sql(
|
|
|
|
self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
|
|
|
|
) -> str:
|
2025-02-13 14:44:19 +01:00
|
|
|
keys = expression.args["keys"]
|
|
|
|
values = expression.args["values"]
|
|
|
|
|
|
|
|
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
|
|
|
|
self.unsupported("Cannot convert array columns into map.")
|
2025-02-13 15:24:45 +01:00
|
|
|
return self.func(map_func_name, keys, values)
|
2025-02-13 14:44:19 +01:00
|
|
|
|
|
|
|
args = []
|
|
|
|
for key, value in zip(keys.expressions, values.expressions):
|
|
|
|
args.append(self.sql(key))
|
|
|
|
args.append(self.sql(value))
|
2025-02-13 15:24:45 +01:00
|
|
|
return self.func(map_func_name, *args)
|
2025-02-13 14:44:19 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def format_time_lambda(
|
|
|
|
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
|
2025-02-13 15:56:32 +01:00
|
|
|
) -> t.Callable[[t.List], E]:
|
2025-02-13 06:15:54 +01:00
|
|
|
"""Helper used for time expressions.
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
Args:
|
|
|
|
exp_class: the expression class to instantiate.
|
|
|
|
dialect: target sql dialect.
|
|
|
|
default: the default format, True being time.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A callable that can be used to return the appropriately formatted time expression.
|
2025-02-13 06:15:54 +01:00
|
|
|
"""
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def _format_time(args: t.List):
|
2025-02-13 06:15:54 +01:00
|
|
|
return exp_class(
|
2025-02-13 14:52:26 +01:00
|
|
|
this=seq_get(args, 0),
|
2025-02-13 06:15:54 +01:00
|
|
|
format=Dialect[dialect].format_time(
|
2025-02-13 15:22:50 +01:00
|
|
|
seq_get(args, 1)
|
|
|
|
or (Dialect[dialect].time_format if default is True else default or None)
|
2025-02-13 06:15:54 +01:00
|
|
|
),
|
|
|
|
)
|
|
|
|
|
|
|
|
return _format_time
|
2025-02-13 14:37:25 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
|
2025-02-13 14:37:25 +01:00
|
|
|
"""
|
|
|
|
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
|
|
|
|
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
|
|
|
|
columns are removed from the create statement.
|
|
|
|
"""
|
|
|
|
has_schema = isinstance(expression.this, exp.Schema)
|
|
|
|
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
|
|
|
|
|
|
|
|
if has_schema and is_partitionable:
|
|
|
|
expression = expression.copy()
|
|
|
|
prop = expression.find(exp.PartitionedByProperty)
|
2025-02-13 15:41:13 +01:00
|
|
|
if prop and prop.this and not isinstance(prop.this, exp.Schema):
|
2025-02-13 14:37:25 +01:00
|
|
|
schema = expression.this
|
2025-02-13 15:41:13 +01:00
|
|
|
columns = {v.name.upper() for v in prop.this.expressions}
|
2025-02-13 14:37:25 +01:00
|
|
|
partitions = [col for col in schema.expressions if col.name.upper() in columns]
|
2025-02-13 14:55:11 +01:00
|
|
|
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
|
|
|
|
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
|
2025-02-13 14:37:25 +01:00
|
|
|
expression.set("this", schema)
|
|
|
|
|
|
|
|
return self.create_sql(expression)
|
2025-02-13 14:51:09 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def parse_date_delta(
|
|
|
|
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
|
2025-02-13 15:56:32 +01:00
|
|
|
) -> t.Callable[[t.List], E]:
|
|
|
|
def inner_func(args: t.List) -> E:
|
2025-02-13 14:51:09 +01:00
|
|
|
unit_based = len(args) == 3
|
2025-02-13 15:51:35 +01:00
|
|
|
this = args[2] if unit_based else seq_get(args, 0)
|
|
|
|
unit = args[0] if unit_based else exp.Literal.string("DAY")
|
2025-02-13 15:56:32 +01:00
|
|
|
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
|
2025-02-13 15:51:35 +01:00
|
|
|
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
|
2025-02-13 14:51:09 +01:00
|
|
|
|
|
|
|
return inner_func
|
2025-02-13 14:53:43 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:51:35 +01:00
|
|
|
def parse_date_delta_with_interval(
|
|
|
|
expression_class: t.Type[E],
|
2025-02-13 15:56:32 +01:00
|
|
|
) -> t.Callable[[t.List], t.Optional[E]]:
|
|
|
|
def func(args: t.List) -> t.Optional[E]:
|
2025-02-13 15:51:35 +01:00
|
|
|
if len(args) < 2:
|
|
|
|
return None
|
|
|
|
|
|
|
|
interval = args[1]
|
|
|
|
expression = interval.this
|
|
|
|
if expression and expression.is_string:
|
|
|
|
expression = exp.Literal.number(expression.this)
|
|
|
|
|
|
|
|
return expression_class(
|
|
|
|
this=args[0],
|
|
|
|
expression=expression,
|
|
|
|
unit=exp.Literal.string(interval.text("unit")),
|
|
|
|
)
|
|
|
|
|
|
|
|
return func
|
|
|
|
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
|
2025-02-13 15:45:33 +01:00
|
|
|
unit = seq_get(args, 0)
|
|
|
|
this = seq_get(args, 1)
|
|
|
|
|
2025-02-13 15:58:03 +01:00
|
|
|
if isinstance(this, exp.Cast) and this.is_type("date"):
|
2025-02-13 15:45:33 +01:00
|
|
|
return exp.DateTrunc(unit=unit, this=this)
|
|
|
|
return exp.TimestampTrunc(this=this, unit=unit)
|
|
|
|
|
|
|
|
|
|
|
|
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
|
|
|
|
return self.func(
|
|
|
|
"DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def locate_to_strposition(args: t.List) -> exp.Expression:
|
2025-02-13 14:53:43 +01:00
|
|
|
return exp.StrPosition(
|
|
|
|
this=seq_get(args, 1),
|
|
|
|
substr=seq_get(args, 0),
|
|
|
|
position=seq_get(args, 2),
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
|
2025-02-13 15:24:45 +01:00
|
|
|
return self.func(
|
|
|
|
"LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
|
2025-02-13 14:53:43 +01:00
|
|
|
)
|
2025-02-13 15:01:11 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:58:03 +01:00
|
|
|
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
|
|
|
|
expression = expression.copy()
|
|
|
|
return self.sql(
|
|
|
|
exp.Substring(
|
|
|
|
this=expression.this, start=exp.Literal.number(1), length=expression.expression
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
|
|
|
|
expression = expression.copy()
|
|
|
|
return self.sql(
|
|
|
|
exp.Substring(
|
|
|
|
this=expression.this,
|
|
|
|
start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
|
2025-02-13 15:01:11 +01:00
|
|
|
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
|
2025-02-13 15:01:11 +01:00
|
|
|
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
2025-02-13 15:02:59 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:42:50 +01:00
|
|
|
def min_or_least(self: Generator, expression: exp.Min) -> str:
|
|
|
|
name = "LEAST" if expression.expressions else "MIN"
|
|
|
|
return rename_func(name)(self, expression)
|
|
|
|
|
|
|
|
|
2025-02-13 15:47:04 +01:00
|
|
|
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
|
|
|
|
name = "GREATEST" if expression.expressions else "MAX"
|
|
|
|
return rename_func(name)(self, expression)
|
|
|
|
|
|
|
|
|
2025-02-13 15:45:33 +01:00
|
|
|
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
|
|
|
|
cond = expression.this
|
|
|
|
|
|
|
|
if isinstance(expression.this, exp.Distinct):
|
|
|
|
cond = expression.this.expressions[0]
|
|
|
|
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
|
|
|
|
|
|
|
|
return self.func("sum", exp.func("if", cond, 1, 0))
|
|
|
|
|
|
|
|
|
2025-02-13 15:22:50 +01:00
|
|
|
def trim_sql(self: Generator, expression: exp.Trim) -> str:
|
2025-02-13 15:02:59 +01:00
|
|
|
target = self.sql(expression, "this")
|
|
|
|
trim_type = self.sql(expression, "position")
|
|
|
|
remove_chars = self.sql(expression, "expression")
|
|
|
|
collation = self.sql(expression, "collation")
|
|
|
|
|
|
|
|
# Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
|
|
|
|
if not remove_chars and not collation:
|
|
|
|
return self.trim_sql(expression)
|
|
|
|
|
|
|
|
trim_type = f"{trim_type} " if trim_type else ""
|
|
|
|
remove_chars = f"{remove_chars} " if remove_chars else ""
|
|
|
|
from_part = "FROM " if trim_type or remove_chars else ""
|
|
|
|
collation = f" COLLATE {collation}" if collation else ""
|
|
|
|
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
2025-02-13 15:28:28 +01:00
|
|
|
|
|
|
|
|
2025-02-13 15:56:32 +01:00
|
|
|
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
|
2025-02-13 15:28:28 +01:00
|
|
|
return self.func("STRPTIME", expression.this, self.format_time(expression))
|
|
|
|
|
|
|
|
|
|
|
|
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
|
|
|
|
def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
|
|
|
|
_dialect = Dialect.get_or_raise(dialect)
|
|
|
|
time_format = self.format_time(expression)
|
|
|
|
if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
|
|
|
|
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
|
|
|
|
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
|
|
|
|
|
|
|
return _ts_or_ds_to_date_sql
|
2025-02-13 15:56:32 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
|
|
|
|
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
|
|
|
|
names = []
|
|
|
|
for agg in aggregations:
|
|
|
|
if isinstance(agg, exp.Alias):
|
|
|
|
names.append(agg.alias)
|
|
|
|
else:
|
|
|
|
"""
|
|
|
|
This case corresponds to aggregations without aliases being used as suffixes
|
|
|
|
(e.g. col_avg(foo)). We need to unquote identifiers because they're going to
|
|
|
|
be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
|
|
|
|
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
|
|
|
|
"""
|
|
|
|
agg_all_unquoted = agg.transform(
|
|
|
|
lambda node: exp.Identifier(this=node.name, quoted=False)
|
|
|
|
if isinstance(node, exp.Identifier)
|
|
|
|
else node
|
|
|
|
)
|
|
|
|
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
|
|
|
|
|
|
|
|
return names
|