Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent ebba7c6a18
commit d26905e4af

187 changed files with 86502 additions and 71397 deletions

The excerpt below shows the hunks for one of those files, which from its contents appears to be sqlglot's dialects/dialect.py.
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
     DIALECT = ""
 
+    ATHENA = "athena"
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    PRQL = "prql"
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
-        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
 
-        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
-        klass.parser_class = getattr(klass, "Parser", Parser)
-        klass.generator_class = getattr(klass, "Generator", Generator)
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
 
         klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
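The switch from `getattr(klass, ...)` to `klass.__dict__.get(...)` plus a dynamically created subclass changes how a dialect's helper classes are resolved. A minimal, self-contained analog of the idea (the classes below are illustrative stand-ins, not sqlglot's actual definitions): a dialect that declares no `Tokenizer` of its own now receives one derived from its parent dialect's `tokenizer_class` instead of the library-wide base.

    class Tokenizer:
        pass

    class _DialectMeta(type):
        def __new__(cls, name, bases, attrs):
            klass = super().__new__(cls, name, bases, attrs)
            base = bases[0] if bases else None
            base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
            # Only honor a Tokenizer declared on this class itself; otherwise
            # synthesize a subclass of the parent dialect's tokenizer.
            klass.tokenizer_class = klass.__dict__.get(
                "Tokenizer", type("Tokenizer", base_tokenizer, {})
            )
            return klass

    class Hive(metaclass=_DialectMeta):
        class Tokenizer(Tokenizer):
            STRING_ESCAPES = ["\\"]

    class Spark(Hive):
        pass  # declares no Tokenizer, yet still inherits Hive's behavior

    assert issubclass(Spark.tokenizer_class, Hive.Tokenizer)
    assert Spark.tokenizer_class.STRING_ESCAPES == ["\\"]

Under the old `getattr` lookup, `Spark.tokenizer_class` would simply have been `Hive.tokenizer_class` itself; the new form gives each dialect a distinct class it can mutate safely.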
@@ -134,9 +143,31 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                "\\a": "\a",
+                "\\b": "\b",
+                "\\f": "\f",
+                "\\n": "\n",
+                "\\r": "\r",
+                "\\t": "\t",
+                "\\v": "\v",
+                "\\\\": "\\",
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "databricks", "hive", "spark", "spark2"):
+            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+            for modifier in ("cluster", "distribute", "sort"):
+                modifier_transforms.pop(modifier, None)
+
+            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
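The new `UNESCAPED_SEQUENCES` default and its `ESCAPED_SEQUENCES` inverse amount to a C-style escape table that is only installed when the tokenizer treats backslash as an escape character. A quick round-trip sanity check of the inversion, in plain Python independent of sqlglot:

    UNESCAPED_SEQUENCES = {"\\a": "\a", "\\b": "\b", "\\f": "\f", "\\n": "\n",
                           "\\r": "\r", "\\t": "\t", "\\v": "\v", "\\\\": "\\"}
    ESCAPED_SEQUENCES = {v: k for k, v in UNESCAPED_SEQUENCES.items()}

    assert ESCAPED_SEQUENCES["\n"] == "\\n"  # a real newline maps back to the two-character sequence
    assert UNESCAPED_SEQUENCES[ESCAPED_SEQUENCES["\t"]] == "\t"  # round trip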
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
         False: Disables function name normalization.
     """
 
-    LOG_BASE_FIRST = True
-    """Whether the base comes first in the `LOG` function."""
+    LOG_BASE_FIRST: t.Optional[bool] = True
+    """
+    Whether the base comes first in the `LOG` function.
+    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+    """
 
     NULL_ORDERING = "nulls_are_small"
     """
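With the type widened to `t.Optional[bool]`, a generator can branch three ways on this flag. A hypothetical sketch of what each value implies for a two-argument `LOG` (not sqlglot's actual generator code):

    from typing import Optional

    def log_sql(log_base_first: Optional[bool], base: str, value: str) -> str:
        # None: this dialect's LOG takes no second argument at all
        if log_base_first is None:
            raise ValueError("two-argument LOG is not supported by this dialect")
        first, second = (base, value) if log_base_first else (value, base)
        return f"LOG({first}, {second})"

    print(log_sql(True, "2", "x"))   # LOG(2, x)
    print(log_sql(False, "2", "x"))  # LOG(x, 2)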
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
         If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
     """
 
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    """Mapping of an unescaped escape sequence to the corresponding character."""
+    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 
     PSEUDOCOLUMNS: t.Set[str] = set()
     """
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
 
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 
     # Delimiters for string literals and identifiers
     QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
     return ""
 
 
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
+    instance = expression.args.get("instance") if generate_instance else None
+    position_offset = ""
+
     if position:
-        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
-    return f"STRPOS({this}, {substr})"
+        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+        this = self.func("SUBSTR", this, position)
+        position_offset = f" + {position} - 1"
+
+    return self.func("STRPOS", this, substr, instance) + position_offset
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
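The rewritten body folds the old f-string special case into a single code path: when a starting position is given, search the suffix with `SUBSTR` and shift the hit back by `position - 1`. The arithmetic behind that offset, shown with 1-based `STRPOS` semantics in plain Python (illustrative only):

    def strpos(haystack: str, needle: str, position: int = 1) -> int:
        # Search only the suffix starting at `position` (1-based), like SUBSTR(haystack, position)
        idx = haystack[position - 1:].find(needle)
        # A hit at 1-based index i within the suffix sits at i + position - 1 in the original
        return 0 if idx == -1 else (idx + 1) + position - 1

    assert strpos("abcabc", "b") == 2
    assert strpos("abcabc", "b", 3) == 5  # found in "cabc", shifted back by 3 - 1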
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
         if expression and expression.is_string:
             expression = exp.Literal.number(expression.this)
 
-        return expression_class(
-            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
-        )
+        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 
     return _builder
 
@@ -710,18 +750,14 @@ def date_add_interval_sql(
 ) -> t.Callable[[Generator, exp.Expression], str]:
     def func(self: Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func
 
 
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
-    return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
-    )
+    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
 
         return self.func(
             name,
-            exp.var(expression.text("unit").upper() or "DAY"),
+            unit_to_var(expression),
             expression.expression,
             expression.this,
         )
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
     return _delta_sql
 
 
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, exp.Placeholder):
+        return unit
+    if unit:
+        return exp.Literal.string(unit.name)
+    return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, (exp.Var, exp.Placeholder)):
+        return unit
+    return exp.Var(this=default) if default else None
+
+
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
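These two helpers centralize the `unit` handling that the earlier hunks inlined at each call site: `unit_to_str` yields a quoted string literal (for function arguments such as `DATE_TRUNC('DAY', ...)`), while `unit_to_var` yields a bare identifier (for keyword positions such as `INTERVAL 1 DAY`). A small usage sketch, assuming the helpers are importable from `sqlglot.dialects.dialect` as this diff defines them:

    from sqlglot import exp
    from sqlglot.dialects.dialect import unit_to_str, unit_to_var

    delta = exp.DateAdd(
        this=exp.column("x"),
        expression=exp.Literal.number(1),
        unit=exp.var("MONTH"),
    )
    print(unit_to_str(delta).sql())  # 'MONTH'  (a string literal)
    print(unit_to_var(delta).sql())  # MONTH    (a bare identifier)
    print(unit_to_str(exp.DateAdd(this=exp.column("x"))).sql())  # 'DAY' (the default)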
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 
 
 def build_json_extract_path(
-    expr_type: t.Type[F], zero_based_indexing: bool = True
+    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
 ) -> t.Callable[[t.List], F]:
     def _builder(args: t.List) -> F:
         segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
 
         # This is done to avoid failing in the expression validator due to the arg count
         del args[2:]
-        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+        return expr_type(
+            this=seq_get(args, 0),
+            expression=exp.JSONPath(expressions=segments),
+            only_json_types=arrow_req_json_type,
+        )
 
     return _builder
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
     unnest = exp.Unnest(expressions=[expression.this])
     filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
     return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+    return self.func(
+        "TO_NUMBER",
+        expression.this,
+        expression.args.get("format"),
+        expression.args.get("nlsparam"),
+    )
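sqlglot's `Generator.func` skips arguments that resolve to nothing, so the optional `format` and `nlsparam` simply drop out when absent, giving `TO_NUMBER(x)`, `TO_NUMBER(x, fmt)`, or `TO_NUMBER(x, fmt, nls)`. A stripped-down illustration of that argument-dropping behavior (not the real `Generator.func` implementation):

    def func(name, *args):
        rendered = [a for a in args if a is not None]
        return f"{name}({', '.join(rendered)})"

    print(func("TO_NUMBER", "x", "'999D99'", None))  # TO_NUMBER(x, '999D99')
    print(func("TO_NUMBER", "x", None, None))        # TO_NUMBER(x)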