Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
8deb804d23
commit
fc63828ee4
167 changed files with 58268 additions and 51337 deletions
|
@ -8,10 +8,16 @@ from sqlglot.generator import Generator
|
|||
from sqlglot.helper import flatten, seq_get
|
||||
from sqlglot.parser import Parser
|
||||
from sqlglot.time import format_time
|
||||
from sqlglot.tokens import Token, Tokenizer
|
||||
from sqlglot.tokens import Token, Tokenizer, TokenType
|
||||
from sqlglot.trie import new_trie
|
||||
|
||||
E = t.TypeVar("E", bound=exp.Expression)
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import E
|
||||
|
||||
|
||||
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
|
||||
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
|
||||
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
|
||||
|
||||
|
||||
class Dialects(str, Enum):
|
||||
|
@ -42,6 +48,19 @@ class Dialects(str, Enum):
|
|||
class _Dialect(type):
|
||||
classes: t.Dict[str, t.Type[Dialect]] = {}
|
||||
|
||||
def __eq__(cls, other: t.Any) -> bool:
|
||||
if cls is other:
|
||||
return True
|
||||
if isinstance(other, str):
|
||||
return cls is cls.get(other)
|
||||
if isinstance(other, Dialect):
|
||||
return cls is type(other)
|
||||
|
||||
return False
|
||||
|
||||
def __hash__(cls) -> int:
|
||||
return hash(cls.__name__.lower())
|
||||
|
||||
@classmethod
|
||||
def __getitem__(cls, key: str) -> t.Type[Dialect]:
|
||||
return cls.classes[key]
|
||||
|
@ -70,17 +89,20 @@ class _Dialect(type):
|
|||
klass.tokenizer_class._IDENTIFIERS.items()
|
||||
)[0]
|
||||
|
||||
klass.bit_start, klass.bit_end = seq_get(
|
||||
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
    """Return the (start, end) delimiters registered for `token_type`.

    Scans `klass.tokenizer_class._FORMAT_STRINGS`, whose items unpack as
    `start -> (end, token type)`, and returns the first pair whose token
    type equals `token_type`.

    Returns:
        The matching `(start, end)` pair, or `(None, None)` when no entry
        has the requested token type.
    """
    # The loop variable is deliberately not named `t`: that name is already
    # used at module level for the annotations in this very signature.
    for start, (end, string_type) in klass.tokenizer_class._FORMAT_STRINGS.items():
        if string_type == token_type:
            return start, end

    return None, None
|
||||
|
||||
klass.hex_start, klass.hex_end = seq_get(
|
||||
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
|
||||
klass.byte_start, klass.byte_end = seq_get(
|
||||
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
|
||||
) or (None, None)
|
||||
klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
|
||||
klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
|
||||
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
|
||||
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
|
||||
|
||||
return klass
|
||||
|
||||
|
@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect):
|
|||
parser_class = None
|
||||
generator_class = None
|
||||
|
||||
def __eq__(self, other: t.Any) -> bool:
|
||||
return type(self) == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(type(self))
|
||||
|
||||
@classmethod
|
||||
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
|
||||
if not dialect:
|
||||
|
@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect):
|
|||
"hex_end": self.hex_end,
|
||||
"byte_start": self.byte_start,
|
||||
"byte_end": self.byte_end,
|
||||
"raw_start": self.raw_start,
|
||||
"raw_end": self.raw_end,
|
||||
"identifier_start": self.identifier_start,
|
||||
"identifier_end": self.identifier_end,
|
||||
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
|
||||
|
@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
|
|||
|
||||
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and generate no SQL for it.

    Args:
        self: The generator producing SQL; its `unsupported` hook is called.
        expression: The PIVOT expression. It is intentionally not rendered.

    Returns:
        An empty string.
    """
    self.unsupported("PIVOT unsupported")
    # The expression is dropped instead of being passed back to `self.sql`;
    # a second return after that call would never be reached.
    return ""
|
||||
|
||||
|
||||
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
|
||||
|
@ -328,7 +358,7 @@ def var_map_sql(
|
|||
|
||||
def format_time_lambda(
|
||||
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
|
||||
) -> t.Callable[[t.Sequence], E]:
|
||||
) -> t.Callable[[t.List], E]:
|
||||
"""Helper used for time expressions.
|
||||
|
||||
Args:
|
||||
|
@ -340,7 +370,7 @@ def format_time_lambda(
|
|||
A callable that can be used to return the appropriately formatted time expression.
|
||||
"""
|
||||
|
||||
def _format_time(args: t.Sequence):
|
||||
def _format_time(args: t.List):
|
||||
return exp_class(
|
||||
this=seq_get(args, 0),
|
||||
format=Dialect[dialect].format_time(
|
||||
|
@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
|
|||
|
||||
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a callback that turns function arguments into a date-delta expression.

    Args:
        exp_class: The expression class to instantiate.
        unit_mapping: Optional mapping from a lowercased unit name to the
            unit name to use instead.

    Returns:
        A callable taking the argument list and returning an `exp_class`
        instance built with `this`, `expression` and `unit`.
    """

    def inner_func(args: t.List) -> E:
        # With exactly three arguments the order is (unit, expression, this);
        # otherwise the first argument is `this` and the unit defaults to DAY.
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        # Map by name and wrap the result in `exp.var`, so `unit` is always an
        # expression; unmapped units fall back to their own name.
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
|
||||
|
@ -390,8 +420,8 @@ def parse_date_delta(
|
|||
|
||||
def parse_date_delta_with_interval(
|
||||
expression_class: t.Type[E],
|
||||
) -> t.Callable[[t.Sequence], t.Optional[E]]:
|
||||
def func(args: t.Sequence) -> t.Optional[E]:
|
||||
) -> t.Callable[[t.List], t.Optional[E]]:
|
||||
def func(args: t.List) -> t.Optional[E]:
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
|
@ -409,7 +439,7 @@ def parse_date_delta_with_interval(
|
|||
return func
|
||||
|
||||
|
||||
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
unit = seq_get(args, 0)
|
||||
this = seq_get(args, 1)
|
||||
|
||||
|
@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
|
|||
)
|
||||
|
||||
|
||||
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
|
||||
def locate_to_strposition(args: t.List) -> exp.Expression:
|
||||
return exp.StrPosition(
|
||||
this=seq_get(args, 1),
|
||||
substr=seq_get(args, 0),
|
||||
|
@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
|
|||
return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
|
||||
|
||||
|
||||
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render `expression` as a `STRPTIME` function call.

    Args:
        self: The generator producing SQL.
        expression: The expression whose `this` is the first argument; its
            time format is obtained from `self.format_time`.

    Returns:
        The SQL produced by `self.func` for `STRPTIME(this, format)`.
    """
    return self.func("STRPTIME", expression.this, self.format_time(expression))
|
||||
|
||||
|
||||
|
@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
|
|||
return f"CAST({self.sql(expression, 'this')} AS DATE)"
|
||||
|
||||
return _ts_or_ds_to_date_sql
|
||||
|
||||
|
||||
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
|
||||
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name contributed by each PIVOT aggregation.

    Args:
        aggregations: The aggregation expressions of the PIVOT.
        dialect: The dialect used to render aggregations that have no alias.

    Returns:
        One name per aggregation, in input order: the alias for aliased
        aggregations, otherwise the aggregation's SQL rendered with all
        identifiers unquoted and `normalize_functions="lower"`.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue