1
0
Fork 0

Merging upstream version 23.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:30:28 +01:00
parent ebba7c6a18
commit d26905e4af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
187 changed files with 86502 additions and 71397 deletions

View file

@ -31,6 +31,7 @@ class Dialects(str, Enum):
DIALECT = ""
ATHENA = "athena"
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
@ -42,6 +43,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
PRQL = "prql"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
@ -108,11 +110,18 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
base = seq_get(bases, 0)
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
base_parser = (getattr(base, "parser_class", Parser),)
base_generator = (getattr(base, "generator_class", Generator),)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer_class = klass.__dict__.get(
"Tokenizer", type("Tokenizer", base_tokenizer, {})
)
klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
klass.generator_class = klass.__dict__.get(
"Generator", type("Generator", base_generator, {})
)
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@ -134,9 +143,31 @@ class _Dialect(type):
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
if "\\" in klass.tokenizer_class.STRING_ESCAPES:
klass.UNESCAPED_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
"\\\\": "\\",
**klass.UNESCAPED_SEQUENCES,
}
klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
if enum not in ("", "databricks", "hive", "spark", "spark2"):
modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
for modifier in ("cluster", "distribute", "sort"):
modifier_transforms.pop(modifier, None)
klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
False: Disables function name normalization.
"""
LOG_BASE_FIRST = True
"""Whether the base comes first in the `LOG` function."""
LOG_BASE_FIRST: t.Optional[bool] = True
"""
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
"""
NULL_ORDERING = "nulls_are_small"
"""
@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
ESCAPE_SEQUENCES: t.Dict[str, str] = {}
"""Mapping of an unescaped escape sequence to the corresponding character."""
UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
"""Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
PSEUDOCOLUMNS: t.Set[str] = set()
"""
@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
ESCAPED_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for string literals and identifiers
QUOTE_START = "'"
@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def str_position_sql(
self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
instance = expression.args.get("instance") if generate_instance else None
position_offset = ""
if position:
return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
return f"STRPOS({this}, {substr})"
# Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
this = self.func("SUBSTR", this, position)
position_offset = f" + {position} - 1"
return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@ -689,9 +731,7 @@ def build_date_delta_with_interval(
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
return expression_class(
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
return _builder
@ -710,18 +750,14 @@ def date_add_interval_sql(
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression, unit=unit)
interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
"DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return self.func(
name,
exp.var(expression.text("unit").upper() or "DAY"),
unit_to_var(expression),
expression.expression,
expression.this,
)
@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
unit = expression.args.get("unit")
if isinstance(unit, exp.Placeholder):
return unit
if unit:
return exp.Literal.string(unit.name)
return exp.Literal.string(default) if default else None
def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
unit = expression.args.get("unit")
if isinstance(unit, (exp.Var, exp.Placeholder)):
return unit
return exp.Var(this=default) if default else None
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
def build_json_extract_path(
expr_type: t.Type[F], zero_based_indexing: bool = True
expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@ -1018,7 +1072,11 @@ def build_json_extract_path(
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
return expr_type(
this=seq_get(args, 0),
expression=exp.JSONPath(expressions=segments),
only_json_types=arrow_req_json_type,
)
return _builder
@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("nlsparam"),
)