Merging upstream version 25.5.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:41:14 +01:00
parent 298e7a8147
commit 029b9c2c73
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
136 changed files with 80990 additions and 72541 deletions

sqlglot/dialects/dialect.py

@@ -9,7 +9,7 @@ from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -122,15 +122,21 @@ class _Dialect(type):
)
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
base = seq_get(bases, 0)
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
base_parser = (getattr(base, "parser_class", Parser),)
base_generator = (getattr(base, "generator_class", Generator),)
klass.tokenizer_class = klass.__dict__.get(
"Tokenizer", type("Tokenizer", base_tokenizer, {})
)
klass.jsonpath_tokenizer_class = klass.__dict__.get(
"JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
)
klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
klass.generator_class = klass.__dict__.get(
"Generator", type("Generator", base_generator, {})
@@ -164,6 +170,8 @@ class _Dialect(type):
klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
@@ -232,6 +240,9 @@ class Dialect(metaclass=_Dialect):
SUPPORTS_COLUMN_JOIN_MARKS = False
"""Whether the old-style outer join (+) syntax is supported."""
COPY_PARAMS_ARE_CSV = True
"""Separator of COPY statement parameters."""
NORMALIZE_FUNCTIONS: bool | str = "upper"
"""
Determines how function names are going to be normalized.
@@ -311,9 +322,44 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""
COPY_PARAMS_ARE_CSV = True
"""
Whether COPY statement parameters are separated by comma or whitespace
"""
FORCE_EARLY_ALIAS_REF_EXPANSION = False
"""
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
SELECT
1 AS id,
2 AS my_id
)
SELECT
id AS my_id
FROM
data
WHERE
my_id = 1
GROUP BY
my_id
HAVING
my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
- BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
"""
EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
"""Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
# --- Autofilled ---
tokenizer_class = Tokenizer
jsonpath_tokenizer_class = JSONPathTokenizer
parser_class = Parser
generator_class = Generator
@@ -323,6 +369,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
INVERSE_FORMAT_TRIE: t.Dict = {}
ESCAPED_SEQUENCES: t.Dict[str, str] = {}
@@ -342,8 +390,99 @@ class Dialect(metaclass=_Dialect):
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None
# Separator of COPY statement parameters
COPY_PARAMS_ARE_CSV = True
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"DAY OF WEEK": "DAYOFWEEK",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"DAY OF YEAR": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MSECS": "MILLISECOND",
"MSECOND": "MILLISECOND",
"MSECONDS": "MILLISECOND",
"MILLISEC": "MILLISECOND",
"MILLISECS": "MILLISECOND",
"MILLISECON": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"USECS": "MICROSECOND",
"MICROSEC": "MICROSECOND",
"MICROSECS": "MICROSECOND",
"USECOND": "MICROSECOND",
"USECONDS": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH_SECOND": "EPOCH",
"EPOCH_SECONDS": "EPOCH",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
"DEC": "DECADE",
"DECS": "DECADE",
"DECADES": "DECADE",
"MIL": "MILLENIUM",
"MILS": "MILLENIUM",
"MILLENIA": "MILLENIUM",
"C": "CENTURY",
"CENT": "CENTURY",
"CENTS": "CENTURY",
"CENTURIES": "CENTURY",
}
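
Not part of the diff: a quick sanity check of the alias table above — every abbreviation canonicalizes to a single spelling.

from sqlglot.dialects.dialect import Dialect

assert all(Dialect.DATE_PART_MAPPING[abbr] == "YEAR" for abbr in ("Y", "YY", "YRS", "YEARS"))
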
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
@@ -371,8 +510,28 @@ class Dialect(metaclass=_Dialect):
return dialect
if isinstance(dialect, str):
try:
dialect_name, *kv_pairs = dialect.split(",")
kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
dialect_name, *kv_strings = dialect.split(",")
kv_pairs = (kv.split("=") for kv in kv_strings)
kwargs = {}
for pair in kv_pairs:
key = pair[0].strip()
value: t.Union[bool, str, None] = None
if len(pair) == 1:
# Default initialize standalone settings to True
value = True
elif len(pair) == 2:
value = pair[1].strip()
# Coerce the value to a boolean if it matches one of the truthy/falsy values below
value_lower = value.lower()
if value_lower in ("true", "1"):
value = True
elif value_lower in ("false", "0"):
value = False
kwargs[key] = value
except ValueError:
raise ValueError(
f"Invalid dialect format: '{dialect}'. "
@@ -410,13 +569,15 @@ class Dialect(metaclass=_Dialect):
return expression
def __init__(self, **kwargs) -> None:
normalization_strategy = kwargs.get("normalization_strategy")
normalization_strategy = kwargs.pop("normalization_strategy", None)
if normalization_strategy is None:
self.normalization_strategy = self.NORMALIZATION_STRATEGY
else:
self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
self.settings = kwargs
def __eq__(self, other: t.Any) -> bool:
# Does not currently take dialect state into account
return type(self) == other
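
Not part of the diff: the revised __init__ above pops normalization_strategy into the enum and stashes every other kwarg in settings (other_setting is an illustrative name).

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

d = Dialect(normalization_strategy="lowercase", other_setting=1)
assert d.normalization_strategy is NormalizationStrategy.LOWERCASE
assert d.settings == {"other_setting": 1}
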
@@ -518,9 +679,8 @@ class Dialect(metaclass=_Dialect):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"
try:
return parse_json_path(path_text)
return parse_json_path(path_text, self)
except ParseError as e:
logger.warning(f"Invalid JSON path syntax. {str(e)}")
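
Not part of the diff: per the call above, parse_json_path now receives the dialect as a second argument, making JSON path tokenization dialect-aware (sketch; the exact signature may differ across versions).

from sqlglot.jsonpath import parse as parse_json_path

node = parse_json_path("$.a[0]", "bigquery")  # returns an exp.JSONPath node
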
@@ -548,9 +708,11 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
self._tokenizer = self.tokenizer_class(dialect=self)
return self._tokenizer
return self.tokenizer_class(dialect=self)
@property
def jsonpath_tokenizer(self) -> JSONPathTokenizer:
return self.jsonpath_tokenizer_class(dialect=self)
def parser(self, **opts) -> Parser:
return self.parser_class(dialect=self, **opts)
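
Not part of the diff: with the uncached property above, each access builds a fresh tokenizer bound to the dialect (sketch).

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("snowflake")
assert d.tokenizer is not d.tokenizer     # new instance per access
assert d.jsonpath_tokenizer.dialect is d  # bound to this dialect
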
@@ -739,13 +901,17 @@ def time_format(
def build_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
exp_class: t.Type[E],
unit_mapping: t.Optional[t.Dict[str, str]] = None,
default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
unit = None
if unit_based or default_unit:
unit = args[0] if unit_based else exp.Literal.string(default_unit)
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return _builder
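
Not part of the diff: a sketch of the new default_unit knob above — passing default_unit=None leaves the unit unset for two-argument calls.

from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

builder = build_date_delta(exp.DateAdd, default_unit=None)
node = builder([exp.column("d"), exp.Literal.number(1)])
assert node.args.get("unit") is None  # two args, no default -> no unit attached
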
@@ -803,19 +969,45 @@ def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.Timesta
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
if not expression.expression:
zone = expression.args.get("zone")
if not zone:
from sqlglot.optimizer.annotate_types import annotate_types
target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
return self.sql(exp.cast(expression.this, target_type))
if expression.text("expression").lower() in TIMEZONES:
if zone.name.lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
zone=expression.expression,
zone=zone,
)
)
return self.func("TIMESTAMP", expression.this, expression.expression)
return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql(self: Generator, expression: exp.Time) -> str:
# Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
expr = exp.cast(
exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
)
return self.sql(expr)
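
Not part of the diff: constructing the same shape directly shows what the helper above emits (default generator; output shown as an example only).

from sqlglot import exp

ts = exp.cast(exp.column("ts"), exp.DataType.Type.TIMESTAMPTZ)
time_expr = exp.cast(
    exp.AtTimeZone(this=ts, zone=exp.Literal.string("America/New_York")),
    exp.DataType.Type.TIME,
)
print(time_expr.sql())
# e.g. CAST(CAST(ts AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York' AS TIME)
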
def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
this = expression.this
expr = expression.expression
if expr.name.lower() in TIMEZONES:
# Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
return self.sql(this)
this = exp.cast(this, exp.DataType.Type.DATE)
expr = exp.cast(expr, exp.DataType.Type.TIME)
return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
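
Not part of the diff: the non-timezone branch above composes DATE + TIME into a TIMESTAMP cast; the equivalent direct construction (sketch).

from sqlglot import exp

d = exp.cast(exp.column("d"), exp.DataType.Type.DATE)
t = exp.cast(exp.column("t"), exp.DataType.Type.TIME)
print(exp.cast(exp.Add(this=d, expression=t), exp.DataType.Type.TIMESTAMP).sql())
# e.g. CAST(CAST(d AS DATE) + CAST(t AS TIME) AS TIMESTAMP)
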
def locate_to_strposition(args: t.List) -> exp.Expression:
@@ -1058,6 +1250,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
return exp.Var(this=default) if default else None
@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
pass
@t.overload
def map_date_part(
part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
pass
def map_date_part(part, dialect: DialectType = Dialect):
mapped = (
Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
)
return exp.var(mapped) if mapped else part
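
Not part of the diff: a usage sketch for map_date_part above — abbreviations resolve through the dialect's DATE_PART_MAPPING and come back wrapped as a Var; unmapped parts pass through unchanged.

from sqlglot import exp
from sqlglot.dialects.dialect import map_date_part

part = map_date_part(exp.var("YY"), "snowflake")
assert isinstance(part, exp.Var) and part.name == "YEAR"
assert map_date_part(exp.var("YEAR")).name == "YEAR"  # no mapping -> unchanged
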
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")