1
0
Fork 0

Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:17:09 +01:00
parent d4fe7bdb16
commit 90988d8258
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 73384 additions and 73067 deletions
sqlglot/dialects

View file

@ -1,14 +1,14 @@
from __future__ import annotations
import typing as t
from enum import Enum
from enum import Enum, auto
from functools import reduce
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@ -16,6 +16,9 @@ from sqlglot.trie import new_trie
B = t.TypeVar("B", bound=exp.Binary)
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
class Dialects(str, Enum):
DIALECT = ""
@ -43,6 +46,15 @@ class Dialects(str, Enum):
Doris = "doris"
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    # Unquoted identifiers are lowercased
    LOWERCASE = auto()

    # Unquoted identifiers are uppercased
    UPPERCASE = auto()

    # Always case-sensitive, regardless of quotes
    CASE_SENSITIVE = auto()

    # Always case-insensitive, regardless of quotes
    CASE_INSENSITIVE = auto()
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
@ -106,26 +118,8 @@ class _Dialect(type):
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
dialect_properties = {
**{
k: v
for k, v in vars(klass).items()
if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
},
"TOKENIZER_CLASS": klass.tokenizer_class,
}
if enum not in ("", "bigquery"):
dialect_properties["SELECT_KINDS"] = ()
# Pass required dialect properties to the tokenizer, parser and generator classes
for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
for name, value in dialect_properties.items():
if hasattr(subclass, name):
setattr(subclass, name, value)
if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
klass.generator_class.SELECT_KINDS = ()
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
@ -133,8 +127,6 @@ class _Dialect(type):
TokenType.SEMI,
}
klass.generator_class.can_identify = klass.can_identify
return klass
@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
# Determines whether or not unquoted identifiers are resolved as uppercase
# When set to None, it means that the dialect treats all identifiers as case-insensitive
RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
# Specifies the strategy according to which identifiers should be normalized.
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect):
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
# Whether the behavior of a / b depends on the types of a and b.
# False means a / b is always float division.
# True means a / b is integer division if both a and b are integers.
TYPED_DIVISION = False
# False means 1 / 0 throws an error.
# True means 1 / 0 returns null.
SAFE_DIVISION = False
# A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
CONCAT_COALESCE = False
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect):
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
# Autofilled
# --- Autofilled ---
tokenizer_class = Tokenizer
parser_class = Parser
generator_class = Generator
@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
def __hash__(self) -> int:
return hash(type(self))
# Delimiters for bit, hex and byte literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = get_or_raise("duckdb")
        >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.
    """
    if not dialect:
        # Falsy input (None, "") yields an instance of the base dialect
        return cls()

    if isinstance(dialect, _Dialect):
        # A dialect class was passed in directly; instantiate it
        return dialect()

    if isinstance(dialect, Dialect):
        return dialect

    if isinstance(dialect, str):
        try:
            # "name, k1 = v1, k2 = v2" -> name plus dialect settings
            dialect_name, *kv_pairs = dialect.split(",")
            kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
        except ValueError:
            raise ValueError(
                f"Invalid dialect format: '{dialect}'. "
                "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
            )

        result = cls.get(dialect_name.strip())
        if not result:
            raise ValueError(f"Unknown dialect '{dialect_name}'.")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
@classmethod
def format_time(
@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect):
return expression
@classmethod
def normalize_identifier(cls, expression: E) -> E:
def __init__(self, **kwargs) -> None:
normalization_strategy = kwargs.get("normalization_strategy")
if normalization_strategy is None:
self.normalization_strategy = self.NORMALIZATION_STRATEGY
else:
self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
def __eq__(self, other: t.Any) -> bool:
# Does not currently take dialect state into account
return type(self) == other
def __hash__(self) -> int:
# Does not currently take dialect state into account
return hash(type(self))
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like FoO would be resolved as foo in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            expression.this.upper()
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else expression.this.lower(),
        )

    return expression
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    # A character is "unsafe" if leaving it unquoted would change it under normalization
    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
"""Checks if text can be identified given an identify option.
Args:
@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect):
return True
if identify == "safe":
return not cls.case_sensitive(text)
return not self.case_sensitive(text)
return False
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """Quote an Identifier node if requested, case-sensitive, or not a safe identifier."""
    if isinstance(expression, exp.Identifier):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
    """Lazily build and cache this dialect's tokenizer, passing the dialect instance along."""
    if not hasattr(self, "_tokenizer"):
        self._tokenizer = self.tokenizer_class(dialect=self)
    return self._tokenizer
def parser(self, **opts) -> Parser:
    """Build a fresh parser for this dialect, forwarding any extra parser options."""
    return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> Generator:
    """Build a fresh SQL generator for this dialect, forwarding any extra generator options."""
    return self.generator_class(dialect=self, **opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c, ...) as a || b || c by left-folding the operands into DPipe nodes."""
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex
return self.func(name, expression.this, expression.expression)
return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the TS_OR_DS_ADD operand in a cast to the expression's return type (in place)."""
    this = expression.this.copy()
    return_type = expression.return_type

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Return a generator transform rendering a date add/diff node as NAME(unit, expression, this).

    Args:
        name: The SQL function name to emit (e.g. "DATE_ADD").
        cast: If True, TS_OR_DS_ADD operands are first cast to the node's return type.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        # Default to a "day" unit when the node doesn't carry one
        return self.func(
            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
        )

    return _delta_sql