Merging upstream version 11.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
fdac67ef7f
commit
ba0f3f0bfa
112 changed files with 126100 additions and 230 deletions
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens, transforms
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
|
@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
E = t.TypeVar("E", bound=exp.Expression)
|
||||
|
||||
def _date_add(expression_class):
|
||||
|
||||
def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
|
||||
def func(args):
|
||||
interval = seq_get(args, 1)
|
||||
return expression_class(
|
||||
|
@ -27,26 +31,26 @@ def _date_add(expression_class):
|
|||
return func
|
||||
|
||||
|
||||
def _date_trunc(args):
|
||||
def _date_trunc(args: t.Sequence) -> exp.Expression:
|
||||
unit = seq_get(args, 1)
|
||||
if isinstance(unit, exp.Column):
|
||||
unit = exp.Var(this=unit.name)
|
||||
return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
|
||||
|
||||
|
||||
def _date_add_sql(data_type, kind):
|
||||
def _date_add_sql(
|
||||
data_type: str, kind: str
|
||||
) -> t.Callable[[generator.Generator, exp.Expression], str]:
|
||||
def func(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
unit = self.sql(expression, "unit") or "'day'"
|
||||
expression = self.sql(expression, "expression")
|
||||
return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
|
||||
return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _derived_table_values_to_unnest(self, expression):
|
||||
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
|
||||
if not isinstance(expression.unnest().parent, exp.From):
|
||||
expression = transforms.remove_precision_parameterized_types(expression)
|
||||
expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
|
||||
return self.values_sql(expression)
|
||||
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
|
||||
structs = []
|
||||
|
@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
|
|||
return self.unnest_sql(unnest_exp)
|
||||
|
||||
|
||||
def _returnsproperty_sql(self, expression):
|
||||
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
|
||||
this = expression.this
|
||||
if isinstance(this, exp.Schema):
|
||||
this = f"{this.this} <{self.expressions(this)}>"
|
||||
|
@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
|
|||
return f"RETURNS {this}"
|
||||
|
||||
|
||||
def _create_sql(self, expression):
|
||||
kind = expression.args.get("kind")
|
||||
def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
|
||||
kind = expression.args["kind"]
|
||||
returns = expression.find(exp.ReturnsProperty)
|
||||
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
|
||||
expression = expression.copy()
|
||||
|
@ -89,6 +93,29 @@ def _create_sql(self, expression):
|
|||
return self.create_sql(expression)
|
||||
|
||||
|
||||
def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
|
||||
"""Remove references to unnest table aliases since bigquery doesn't allow them.
|
||||
|
||||
These are added by the optimizer's qualify_column step.
|
||||
"""
|
||||
if isinstance(expression, exp.Select):
|
||||
unnests = {
|
||||
unnest.alias
|
||||
for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
|
||||
if isinstance(unnest, exp.Unnest) and unnest.alias
|
||||
}
|
||||
|
||||
if unnests:
|
||||
expression = expression.copy()
|
||||
|
||||
for select in expression.expressions:
|
||||
for column in select.find_all(exp.Column):
|
||||
if column.table in unnests:
|
||||
column.set("table", None)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
class BigQuery(Dialect):
|
||||
unnest_column_only = True
|
||||
time_mapping = {
|
||||
|
@ -110,7 +137,7 @@ class BigQuery(Dialect):
|
|||
]
|
||||
COMMENTS = ["--", "#", ("/*", "*/")]
|
||||
IDENTIFIERS = ["`"]
|
||||
ESCAPES = ["\\"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
HEX_STRINGS = [("0x", ""), ("0X", "")]
|
||||
|
||||
KEYWORDS = {
|
||||
|
@ -190,6 +217,9 @@ class BigQuery(Dialect):
|
|||
exp.GroupConcat: rename_func("STRING_AGG"),
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.IntDiv: rename_func("DIV"),
|
||||
exp.Select: transforms.preprocess(
|
||||
[_unqualify_unnest], transforms.delegate("select_sql")
|
||||
),
|
||||
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
|
||||
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
|
||||
exp.TimeSub: _date_add_sql("TIME", "SUB"),
|
||||
|
|
|
@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
|
|||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
def _lower_func(sql):
|
||||
def _lower_func(sql: str) -> str:
|
||||
index = sql.index("(")
|
||||
return sql[:index].lower() + sql[index:]
|
||||
|
||||
|
|
|
@ -11,6 +11,8 @@ from sqlglot.time import format_time
|
|||
from sqlglot.tokens import Tokenizer
|
||||
from sqlglot.trie import new_trie
|
||||
|
||||
E = t.TypeVar("E", bound=exp.Expression)
|
||||
|
||||
|
||||
class Dialects(str, Enum):
|
||||
DIALECT = ""
|
||||
|
@ -37,14 +39,16 @@ class Dialects(str, Enum):
|
|||
|
||||
|
||||
class _Dialect(type):
|
||||
classes: t.Dict[str, Dialect] = {}
|
||||
classes: t.Dict[str, t.Type[Dialect]] = {}
|
||||
|
||||
@classmethod
|
||||
def __getitem__(cls, key):
|
||||
def __getitem__(cls, key: str) -> t.Type[Dialect]:
|
||||
return cls.classes[key]
|
||||
|
||||
@classmethod
|
||||
def get(cls, key, default=None):
|
||||
def get(
|
||||
cls, key: str, default: t.Optional[t.Type[Dialect]] = None
|
||||
) -> t.Optional[t.Type[Dialect]]:
|
||||
return cls.classes.get(key, default)
|
||||
|
||||
def __new__(cls, clsname, bases, attrs):
|
||||
|
@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
|
|||
generator_class = None
|
||||
|
||||
@classmethod
|
||||
def get_or_raise(cls, dialect):
|
||||
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
|
||||
if not dialect:
|
||||
return cls
|
||||
if isinstance(dialect, _Dialect):
|
||||
|
@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
|
|||
return result
|
||||
|
||||
@classmethod
|
||||
def format_time(cls, expression):
|
||||
def format_time(
|
||||
cls, expression: t.Optional[str | exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if isinstance(expression, str):
|
||||
return exp.Literal.string(
|
||||
format_time(
|
||||
|
@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
|
|||
)
|
||||
return expression
|
||||
|
||||
def parse(self, sql, **opts):
|
||||
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
|
||||
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
|
||||
|
||||
def parse_into(self, expression_type, sql, **opts):
|
||||
def parse_into(
|
||||
self, expression_type: exp.IntoType, sql: str, **opts
|
||||
) -> t.List[t.Optional[exp.Expression]]:
|
||||
return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
|
||||
|
||||
def generate(self, expression, **opts):
|
||||
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
|
||||
return self.generator(**opts).generate(expression)
|
||||
|
||||
def transpile(self, code, **opts):
|
||||
return self.generate(self.parse(code), **opts)
|
||||
def transpile(self, sql: str, **opts) -> t.List[str]:
|
||||
return [self.generate(expression, **opts) for expression in self.parse(sql)]
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
def tokenizer(self) -> Tokenizer:
|
||||
if not hasattr(self, "_tokenizer"):
|
||||
self._tokenizer = self.tokenizer_class()
|
||||
self._tokenizer = self.tokenizer_class() # type: ignore
|
||||
return self._tokenizer
|
||||
|
||||
def parser(self, **opts):
|
||||
return self.parser_class(
|
||||
def parser(self, **opts) -> Parser:
|
||||
return self.parser_class( # type: ignore
|
||||
**{
|
||||
"index_offset": self.index_offset,
|
||||
"unnest_column_only": self.unnest_column_only,
|
||||
|
@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
|
|||
},
|
||||
)
|
||||
|
||||
def generator(self, **opts):
|
||||
return self.generator_class(
|
||||
def generator(self, **opts) -> Generator:
|
||||
return self.generator_class( # type: ignore
|
||||
**{
|
||||
"quote_start": self.quote_start,
|
||||
"quote_end": self.quote_end,
|
||||
"identifier_start": self.identifier_start,
|
||||
"identifier_end": self.identifier_end,
|
||||
"escape": self.tokenizer_class.ESCAPES[0],
|
||||
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
|
||||
"identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
|
||||
"index_offset": self.index_offset,
|
||||
"time_mapping": self.inverse_time_mapping,
|
||||
"time_trie": self.inverse_time_trie,
|
||||
|
@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
|
|||
)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
|
||||
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
|
||||
|
||||
|
||||
def rename_func(name):
|
||||
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
|
||||
def _rename(self, expression):
|
||||
args = flatten(expression.args.values())
|
||||
return f"{self.normalize_func(name)}({self.format_args(*args)})"
|
||||
|
@ -214,32 +222,34 @@ def rename_func(name):
|
|||
return _rename
|
||||
|
||||
|
||||
def approx_count_distinct_sql(self, expression):
|
||||
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
|
||||
if expression.args.get("accuracy"):
|
||||
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
|
||||
return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
|
||||
|
||||
|
||||
def if_sql(self, expression):
|
||||
def if_sql(self: Generator, expression: exp.If) -> str:
|
||||
expressions = self.format_args(
|
||||
expression.this, expression.args.get("true"), expression.args.get("false")
|
||||
)
|
||||
return f"IF({expressions})"
|
||||
|
||||
|
||||
def arrow_json_extract_sql(self, expression):
|
||||
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
|
||||
return self.binary(expression, "->")
|
||||
|
||||
|
||||
def arrow_json_extract_scalar_sql(self, expression):
|
||||
def arrow_json_extract_scalar_sql(
|
||||
self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
|
||||
) -> str:
|
||||
return self.binary(expression, "->>")
|
||||
|
||||
|
||||
def inline_array_sql(self, expression):
|
||||
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
|
||||
return f"[{self.expressions(expression)}]"
|
||||
|
||||
|
||||
def no_ilike_sql(self, expression):
|
||||
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
|
||||
return self.like_sql(
|
||||
exp.Like(
|
||||
this=exp.Lower(this=expression.this),
|
||||
|
@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
|
|||
)
|
||||
|
||||
|
||||
def no_paren_current_date_sql(self, expression):
|
||||
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
|
||||
zone = self.sql(expression, "this")
|
||||
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
|
||||
|
||||
|
||||
def no_recursive_cte_sql(self, expression):
|
||||
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
|
||||
if expression.args.get("recursive"):
|
||||
self.unsupported("Recursive CTEs are unsupported")
|
||||
expression.args["recursive"] = False
|
||||
return self.with_sql(expression)
|
||||
|
||||
|
||||
def no_safe_divide_sql(self, expression):
|
||||
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
|
||||
n = self.sql(expression, "this")
|
||||
d = self.sql(expression, "expression")
|
||||
return f"IF({d} <> 0, {n} / {d}, NULL)"
|
||||
|
||||
|
||||
def no_tablesample_sql(self, expression):
|
||||
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
|
||||
self.unsupported("TABLESAMPLE unsupported")
|
||||
return self.sql(expression.this)
|
||||
|
||||
|
||||
def no_pivot_sql(self, expression):
|
||||
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
|
||||
self.unsupported("PIVOT unsupported")
|
||||
return self.sql(expression)
|
||||
|
||||
|
||||
def no_trycast_sql(self, expression):
|
||||
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
|
||||
return self.cast_sql(expression)
|
||||
|
||||
|
||||
def no_properties_sql(self, expression):
|
||||
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
|
||||
self.unsupported("Properties unsupported")
|
||||
return ""
|
||||
|
||||
|
||||
def str_position_sql(self, expression):
|
||||
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
substr = self.sql(expression, "substr")
|
||||
position = self.sql(expression, "position")
|
||||
|
@ -294,13 +304,15 @@ def str_position_sql(self, expression):
|
|||
return f"STRPOS({this}, {substr})"
|
||||
|
||||
|
||||
def struct_extract_sql(self, expression):
|
||||
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
|
||||
return f"{this}.{struct_key}"
|
||||
|
||||
|
||||
def var_map_sql(self, expression, map_func_name="MAP"):
|
||||
def var_map_sql(
|
||||
self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
|
||||
) -> str:
|
||||
keys = expression.args["keys"]
|
||||
values = expression.args["values"]
|
||||
|
||||
|
@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
|
|||
return f"{map_func_name}({self.format_args(*args)})"
|
||||
|
||||
|
||||
def format_time_lambda(exp_class, dialect, default=None):
|
||||
def format_time_lambda(
|
||||
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
|
||||
) -> t.Callable[[t.Sequence], E]:
|
||||
"""Helper used for time expressions.
|
||||
|
||||
Args
|
||||
exp_class (Class): the expression class to instantiate
|
||||
dialect (string): sql dialect
|
||||
default (Option[bool | str]): the default format, True being time
|
||||
Args:
|
||||
exp_class: the expression class to instantiate.
|
||||
dialect: target sql dialect.
|
||||
default: the default format, True being time.
|
||||
|
||||
Returns:
|
||||
A callable that can be used to return the appropriately formatted time expression.
|
||||
"""
|
||||
|
||||
def _format_time(args):
|
||||
def _format_time(args: t.Sequence):
|
||||
return exp_class(
|
||||
this=seq_get(args, 0),
|
||||
format=Dialect[dialect].format_time(
|
||||
seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
|
||||
seq_get(args, 1)
|
||||
or (Dialect[dialect].time_format if default is True else default or None)
|
||||
),
|
||||
)
|
||||
|
||||
return _format_time
|
||||
|
||||
|
||||
def create_with_partitions_sql(self, expression):
|
||||
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
|
||||
"""
|
||||
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
|
||||
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
|
||||
|
@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
|
|||
return self.create_sql(expression)
|
||||
|
||||
|
||||
def parse_date_delta(exp_class, unit_mapping=None):
|
||||
def inner_func(args):
|
||||
def parse_date_delta(
|
||||
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
|
||||
) -> t.Callable[[t.Sequence], E]:
|
||||
def inner_func(args: t.Sequence) -> E:
|
||||
unit_based = len(args) == 3
|
||||
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
|
||||
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
|
||||
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
|
||||
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
|
||||
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
|
||||
return exp_class(this=this, expression=expression, unit=unit)
|
||||
|
||||
return inner_func
|
||||
|
||||
|
||||
def locate_to_strposition(args):
|
||||
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
|
||||
return exp.StrPosition(
|
||||
this=seq_get(args, 1),
|
||||
substr=seq_get(args, 0),
|
||||
|
@ -379,22 +399,22 @@ def locate_to_strposition(args):
|
|||
)
|
||||
|
||||
|
||||
def strposition_to_locate_sql(self, expression):
|
||||
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
|
||||
args = self.format_args(
|
||||
expression.args.get("substr"), expression.this, expression.args.get("position")
|
||||
)
|
||||
return f"LOCATE({args})"
|
||||
|
||||
|
||||
def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
|
||||
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
|
||||
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
|
||||
|
||||
|
||||
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
|
||||
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
|
||||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||
|
||||
|
||||
def trim_sql(self, expression):
|
||||
def trim_sql(self: Generator, expression: exp.Trim) -> str:
|
||||
target = self.sql(expression, "this")
|
||||
trim_type = self.sql(expression, "position")
|
||||
remove_chars = self.sql(expression, "expression")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import (
|
||||
|
@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import (
|
|||
)
|
||||
|
||||
|
||||
def _to_timestamp(args):
|
||||
# TO_TIMESTAMP accepts either a single double argument or (text, text)
|
||||
if len(args) == 1 and args[0].is_number:
|
||||
return exp.UnixToTime.from_arg_list(args)
|
||||
return format_time_lambda(exp.StrToTime, "drill")(args)
|
||||
|
||||
|
||||
def _str_to_time_sql(self, expression):
|
||||
def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
|
||||
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
|
||||
|
||||
|
||||
def _ts_or_ds_to_date_sql(self, expression):
|
||||
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
|
||||
time_format = self.format_time(expression)
|
||||
if time_format and time_format not in (Drill.time_format, Drill.date_format):
|
||||
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
|
||||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||
|
||||
|
||||
def _date_add_sql(kind):
|
||||
def func(self, expression):
|
||||
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
|
||||
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
unit = expression.text("unit").upper() or "DAY"
|
||||
expression = self.sql(expression, "expression")
|
||||
return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
|
||||
unit = exp.Var(this=expression.text("unit").upper() or "DAY")
|
||||
return (
|
||||
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def if_sql(self, expression):
|
||||
def if_sql(self: generator.Generator, expression: exp.If) -> str:
|
||||
"""
|
||||
Drill requires backticks around certain SQL reserved words, IF being one of them, This function
|
||||
adds the backticks around the keyword IF.
|
||||
|
@ -61,7 +56,7 @@ def if_sql(self, expression):
|
|||
return f"`IF`({expressions})"
|
||||
|
||||
|
||||
def _str_to_date(self, expression):
|
||||
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
time_format = self.format_time(expression)
|
||||
if time_format == Drill.date_format:
|
||||
|
@ -111,7 +106,7 @@ class Drill(Dialect):
|
|||
class Tokenizer(tokens.Tokenizer):
|
||||
QUOTES = ["'"]
|
||||
IDENTIFIERS = ["`"]
|
||||
ESCAPES = ["\\"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
ENCODE = "utf-8"
|
||||
|
||||
class Parser(parser.Parser):
|
||||
|
@ -168,10 +163,10 @@ class Drill(Dialect):
|
|||
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
|
||||
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
|
||||
exp.TryCast: no_trycast_sql,
|
||||
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
|
||||
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
|
||||
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
|
||||
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
|
||||
}
|
||||
|
||||
def normalize_func(self, name):
|
||||
def normalize_func(self, name: str) -> str:
|
||||
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
|
||||
|
|
|
@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
|
|||
|
||||
|
||||
def _ts_or_ds_add(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
e = self.sql(expression, "expression")
|
||||
this = expression.args.get("this")
|
||||
unit = self.sql(expression, "unit").strip("'") or "DAY"
|
||||
return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
|
||||
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
|
||||
|
||||
def _ts_or_ds_to_date_sql(self, expression):
|
||||
|
@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
|
|||
|
||||
def _date_add(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
e = self.sql(expression, "expression")
|
||||
unit = self.sql(expression, "unit").strip("'") or "DAY"
|
||||
return f"{this} + INTERVAL {e} {unit}"
|
||||
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
|
||||
|
||||
|
||||
def _array_sort_sql(self, expression):
|
||||
|
|
|
@ -172,7 +172,7 @@ class Hive(Dialect):
|
|||
class Tokenizer(tokens.Tokenizer):
|
||||
QUOTES = ["'", '"']
|
||||
IDENTIFIERS = ["`"]
|
||||
ESCAPES = ["\\"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
ENCODE = "utf-8"
|
||||
|
||||
KEYWORDS = {
|
||||
|
|
|
@ -89,8 +89,9 @@ def _date_add_sql(kind):
|
|||
def func(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
unit = expression.text("unit").upper() or "DAY"
|
||||
expression = self.sql(expression, "expression")
|
||||
return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
|
||||
return (
|
||||
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
|
||||
)
|
||||
|
||||
return func
|
||||
|
||||
|
@ -117,7 +118,7 @@ class MySQL(Dialect):
|
|||
QUOTES = ["'", '"']
|
||||
COMMENTS = ["--", "#", ("/*", "*/")]
|
||||
IDENTIFIERS = ["`"]
|
||||
ESCAPES = ["'", "\\"]
|
||||
STRING_ESCAPES = ["'", "\\"]
|
||||
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
|
||||
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
|
||||
|
||||
|
|
|
@ -40,8 +40,7 @@ def _date_add_sql(kind):
|
|||
|
||||
expression = expression.copy()
|
||||
expression.args["is_string"] = True
|
||||
expression = self.sql(expression)
|
||||
return f"{this} {kind} INTERVAL {expression} {unit}"
|
||||
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
|
||||
|
||||
return func
|
||||
|
||||
|
|
|
@ -37,11 +37,10 @@ class Redshift(Postgres):
|
|||
return this
|
||||
|
||||
class Tokenizer(Postgres.Tokenizer):
|
||||
ESCAPES = ["\\"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
|
||||
KEYWORDS = {
|
||||
**Postgres.Tokenizer.KEYWORDS, # type: ignore
|
||||
"COPY": TokenType.COMMAND,
|
||||
"ENCODE": TokenType.ENCODE,
|
||||
"GEOMETRY": TokenType.GEOMETRY,
|
||||
"GEOGRAPHY": TokenType.GEOGRAPHY,
|
||||
|
|
|
@ -180,7 +180,7 @@ class Snowflake(Dialect):
|
|||
|
||||
class Tokenizer(tokens.Tokenizer):
|
||||
QUOTES = ["'", "$$"]
|
||||
ESCAPES = ["\\", "'"]
|
||||
STRING_ESCAPES = ["\\", "'"]
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
**tokens.Tokenizer.SINGLE_TOKENS,
|
||||
|
@ -191,6 +191,7 @@ class Snowflake(Dialect):
|
|||
**tokens.Tokenizer.KEYWORDS,
|
||||
"EXCLUDE": TokenType.EXCEPT,
|
||||
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
|
||||
"PUT": TokenType.COMMAND,
|
||||
"RENAME": TokenType.REPLACE,
|
||||
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
||||
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
||||
|
@ -222,6 +223,7 @@ class Snowflake(Dialect):
|
|||
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
|
||||
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
||||
exp.UnixToTime: _unix_to_time_sql,
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
@ -294,3 +296,12 @@ class Snowflake(Dialect):
|
|||
kind = f" {kind_value}" if kind_value else ""
|
||||
this = f" {self.sql(expression, 'this')}"
|
||||
return f"DESCRIBE{kind}{this}"
|
||||
|
||||
def generatedasidentitycolumnconstraint_sql(
|
||||
self, expression: exp.GeneratedAsIdentityColumnConstraint
|
||||
) -> str:
|
||||
start = expression.args.get("start")
|
||||
start = f" START {start}" if start else ""
|
||||
increment = expression.args.get("increment")
|
||||
increment = f" INCREMENT {increment}" if increment else ""
|
||||
return f"AUTOINCREMENT{start}{increment}"
|
||||
|
|
|
@ -157,6 +157,7 @@ class Spark(Hive):
|
|||
TRANSFORMS.pop(exp.ILike)
|
||||
|
||||
WRAP_DERIVED_VALUES = False
|
||||
CREATE_FUNCTION_AS = False
|
||||
|
||||
def cast_sql(self, expression: exp.Cast) -> str:
|
||||
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
|
||||
|
|
|
@ -49,7 +49,6 @@ class SQLite(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue