
Merging upstream version 15.0.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:57:23 +01:00
parent 8deb804d23
commit fc63828ee4
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
167 changed files with 58268 additions and 51337 deletions

sqlglot/dialects/dialect.py

@@ -8,10 +8,16 @@ from sqlglot.generator import Generator
 from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
-from sqlglot.tokens import Token, Tokenizer
+from sqlglot.tokens import Token, Tokenizer, TokenType
 from sqlglot.trie import new_trie
 
-E = t.TypeVar("E", bound=exp.Expression)
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
+# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
+# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
 
 
 class Dialects(str, Enum):
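
For context (not part of the commit): RESOLVES_IDENTIFIERS_AS_UPPERCASE records which dialects fold unquoted identifiers to uppercase during name resolution, per the linked Snowflake documentation. A minimal sketch of how a consumer might branch on it; the helper name is hypothetical:

    from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE

    def fold_unquoted(name: str, dialect: str) -> str:
        # Hypothetical helper: unquoted identifiers resolve case-insensitively,
        # folding to uppercase on Snowflake and to lowercase elsewhere.
        if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            return name.upper()
        return name.lower()
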
@@ -42,6 +48,19 @@ class Dialects(str, Enum):
 class _Dialect(type):
     classes: t.Dict[str, t.Type[Dialect]] = {}
 
+    def __eq__(cls, other: t.Any) -> bool:
+        if cls is other:
+            return True
+        if isinstance(other, str):
+            return cls is cls.get(other)
+        if isinstance(other, Dialect):
+            return cls is type(other)
+
+        return False
+
+    def __hash__(cls) -> int:
+        return hash(cls.__name__.lower())
+
     @classmethod
     def __getitem__(cls, key: str) -> t.Type[Dialect]:
         return cls.classes[key]
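
For context (not part of the commit): defining __eq__ and __hash__ on the metaclass lets a Dialect subclass compare equal to its registry name or to one of its instances while staying hashable. Assuming the usual dialect registry:

    from sqlglot.dialects import Snowflake

    assert Snowflake == "snowflake"    # resolved through cls.get(other)
    assert Snowflake == Snowflake()    # an instance of the class also matches
    assert Snowflake != "bigquery"     # cls.get returns a different class
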
@@ -70,17 +89,20 @@ class _Dialect(type):
             klass.tokenizer_class._IDENTIFIERS.items()
         )[0]
 
-        klass.bit_start, klass.bit_end = seq_get(
-            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
-        ) or (None, None)
+        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
+            return next(
+                (
+                    (s, e)
+                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
+                    if t == token_type
+                ),
+                (None, None),
+            )
 
-        klass.hex_start, klass.hex_end = seq_get(
-            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
-        ) or (None, None)
-        klass.byte_start, klass.byte_end = seq_get(
-            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
-        ) or (None, None)
+        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
+        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
+        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
+        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
 
         return klass
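
For context (not part of the commit): the three per-kind lookups are replaced by one query over the tokenizer's consolidated _FORMAT_STRINGS table, which also covers the new raw strings. Its shape below is an assumption inferred from the generator expression above:

    # Assumed shape: start delimiter -> (end delimiter, token type), e.g.
    _FORMAT_STRINGS = {
        "b'": ("'", TokenType.BIT_STRING),
        "x'": ("'", TokenType.HEX_STRING),
        "e'": ("'", TokenType.BYTE_STRING),
    }
    # get_start_end(TokenType.HEX_STRING) then yields ("x'", "'"); a kind with
    # no entry falls through to the (None, None) default.
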
@@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
     parser_class = None
     generator_class = None
 
+    def __eq__(self, other: t.Any) -> bool:
+        return type(self) == other
+
+    def __hash__(self) -> int:
+        return hash(type(self))
+
     @classmethod
     def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
         if not dialect:
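
For context (not part of the commit): the instance-level methods delegate to the metaclass ones above, so an instance, its class, and the dialect name all agree:

    from sqlglot.dialects import DuckDB

    assert DuckDB() == DuckDB               # type(self) == other hits the metaclass __eq__
    assert hash(DuckDB()) == hash(DuckDB)   # both reduce to hash("duckdb")
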
@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"raw_start": self.raw_start,
"raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
@@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
     self.unsupported("PIVOT unsupported")
-    return self.sql(expression)
+    return ""
 
 
 def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
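
For context (not part of the commit): with the warning already recorded via self.unsupported, returning "" omits the PIVOT clause from the generated SQL instead of rendering a clause the target dialect cannot express. A dialect opts in by routing exp.Pivot to the helper, roughly like this (the generator class is hypothetical):

    from sqlglot import exp
    from sqlglot.generator import Generator
    from sqlglot.dialects.dialect import no_pivot_sql

    class NoPivotGenerator(Generator):
        # Hypothetical generator: exp.Pivot is dropped from the output
        # (with an "unsupported" warning) rather than emitted.
        TRANSFORMS = {**Generator.TRANSFORMS, exp.Pivot: no_pivot_sql}
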
@@ -328,7 +358,7 @@ def var_map_sql(
 def format_time_lambda(
     exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
-) -> t.Callable[[t.Sequence], E]:
+) -> t.Callable[[t.List], E]:
     """Helper used for time expressions.
 
     Args:
@@ -340,7 +370,7 @@ def format_time_lambda(
         A callable that can be used to return the appropriately formatted time expression.
     """
 
-    def _format_time(args: t.Sequence):
+    def _format_time(args: t.List):
         return exp_class(
             this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
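
For context (not part of the commit): the callable that format_time_lambda returns is used as a function-parsing hook; it converts a dialect-specific format literal into strftime tokens. A sketch assuming Hive's documented time mapping:

    from sqlglot import exp
    from sqlglot.dialects.dialect import format_time_lambda

    parse_to_date = format_time_lambda(exp.TsOrDsToDate, "hive")
    node = parse_to_date([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
    # Hive's "yyyy-MM-dd" is rewritten to the strftime form "%Y-%m-%d".
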
@@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
 def parse_date_delta(
     exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
-) -> t.Callable[[t.Sequence], E]:
-    def inner_func(args: t.Sequence) -> E:
+) -> t.Callable[[t.List], E]:
+    def inner_func(args: t.List) -> E:
         unit_based = len(args) == 3
         this = args[2] if unit_based else seq_get(args, 0)
         unit = args[0] if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
         return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 
     return inner_func
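
For context (not part of the commit): wrapping the mapped unit in exp.var(...) puts a proper Var node into the tree, where the old line could leave a bare string when unit_mapping rewrote the unit. A sketch, with a hypothetical unit_mapping:

    from sqlglot import exp
    from sqlglot.dialects.dialect import parse_date_delta

    parse_dateadd = parse_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY"})
    node = parse_dateadd(
        [exp.var("dd"), exp.Literal.number(1), exp.column("order_date")]
    )
    assert node.args["unit"].name == "DAY"  # mapped, then wrapped in exp.var
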
@@ -390,8 +420,8 @@ def parse_date_delta(
 def parse_date_delta_with_interval(
     expression_class: t.Type[E],
-) -> t.Callable[[t.Sequence], t.Optional[E]]:
-    def func(args: t.Sequence) -> t.Optional[E]:
+) -> t.Callable[[t.List], t.Optional[E]]:
+    def func(args: t.List) -> t.Optional[E]:
         if len(args) < 2:
             return None
@@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
     return func
 
 
-def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
     unit = seq_get(args, 0)
     this = seq_get(args, 1)
@@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
     )
 
 
-def locate_to_strposition(args: t.Sequence) -> exp.Expression:
+def locate_to_strposition(args: t.List) -> exp.Expression:
     return exp.StrPosition(
         this=seq_get(args, 1),
         substr=seq_get(args, 0),
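
For context (not part of the commit): LOCATE takes the needle first and the haystack second, while exp.StrPosition stores the haystack as this, so the converter swaps the first two arguments:

    from sqlglot import exp
    from sqlglot.dialects.dialect import locate_to_strposition

    node = locate_to_strposition([exp.Literal.string("a"), exp.column("s")])
    assert node.this.name == "s"            # haystack
    assert node.args["substr"].name == "a"  # needle
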
@@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 
 
-def str_to_time_sql(self, expression: exp.Expression) -> str:
+def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
     return self.func("STRPTIME", expression.this, self.format_time(expression))
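
For context (not part of the commit): this hunk only adds the missing Generator annotation on self; the emitted SQL is unchanged. Assuming a dialect that maps exp.StrToTime to this helper, as DuckDB does upstream:

    from sqlglot import exp
    from sqlglot.dialects import DuckDB

    node = exp.StrToTime(this=exp.column("x"), format=exp.Literal.string("%Y-%m-%d"))
    DuckDB().generate(node)  # -> "STRPTIME(x, '%Y-%m-%d')"
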
@@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
         return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
     return _ts_or_ds_to_date_sql
+
+
+# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
+def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
+    names = []
+    for agg in aggregations:
+        if isinstance(agg, exp.Alias):
+            names.append(agg.alias)
+        else:
+            """
+            This case corresponds to aggregations without aliases being used as suffixes
+            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
+            """
+            agg_all_unquoted = agg.transform(
+                lambda node: exp.Identifier(this=node.name, quoted=False)
+                if isinstance(node, exp.Identifier)
+                else node
+            )
+            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
+
+    return names
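
For context (not part of the commit): how the two branches name PIVOT output columns, using parse_one to build the aggregation nodes:

    from sqlglot import parse_one
    from sqlglot.dialects.dialect import pivot_column_names

    aliased = parse_one("SUM(foo) AS total")   # exp.Alias -> use the alias as-is
    unaliased = parse_one('MAX("foo")')        # no alias -> unquote and lowercase

    pivot_column_names([aliased, unaliased], dialect="duckdb")
    # -> ['total', 'max(foo)']
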