Merging upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent d4fe7bdb16
commit 90988d8258

127 changed files with 73384 additions and 73067 deletions
sqlglot/dialects

@@ -1,14 +1,14 @@
 from __future__ import annotations
 
 import typing as t
-from enum import Enum
+from enum import Enum, auto
 from functools import reduce
 
 from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import AutoName, flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -16,6 +16,9 @@ from sqlglot.trie import new_trie
 B = t.TypeVar("B", bound=exp.Binary)
 
+DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
+DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+
 
 class Dialects(str, Enum):
     DIALECT = ""
@@ -43,6 +46,15 @@ class Dialects(str, Enum):
     Doris = "doris"
 
 
+class NormalizationStrategy(str, AutoName):
+    """Specifies the strategy according to which identifiers should be normalized."""
+
+    LOWERCASE = auto()  # Unquoted identifiers are lowercased
+    UPPERCASE = auto()  # Unquoted identifiers are uppercased
+    CASE_SENSITIVE = auto()  # Always case-sensitive, regardless of quotes
+    CASE_INSENSITIVE = auto()  # Always case-insensitive, regardless of quotes
+
+
 class _Dialect(type):
     classes: t.Dict[str, t.Type[Dialect]] = {}
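Context for the new enum (a sketch for illustration, not part of the diff): sqlglot's AutoName helper assigns each member its own name as its value, which is what lets a strategy be built from a plain string in Dialect.__init__ further down.

    from sqlglot.dialects.dialect import NormalizationStrategy

    # AutoName members use their own names as values, so lookup by string works,
    # matching the NormalizationStrategy(value.upper()) call in Dialect.__init__:
    assert NormalizationStrategy("CASE_SENSITIVE") is NormalizationStrategy.CASE_SENSITIVE
    assert NormalizationStrategy.LOWERCASE.value == "LOWERCASE"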
@@ -106,26 +118,8 @@ class _Dialect(type):
         klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 
-        dialect_properties = {
-            **{
-                k: v
-                for k, v in vars(klass).items()
-                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
-            },
-            "TOKENIZER_CLASS": klass.tokenizer_class,
-        }
-
         if enum not in ("", "bigquery"):
-            dialect_properties["SELECT_KINDS"] = ()
-
-        # Pass required dialect properties to the tokenizer, parser and generator classes
-        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
-            for name, value in dialect_properties.items():
-                if hasattr(subclass, name):
-                    setattr(subclass, name, value)
-
-        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
-            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
+            klass.generator_class.SELECT_KINDS = ()
 
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
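The block removed above is the heart of this release's refactor: dialect settings are no longer stamped onto the tokenizer, parser and generator classes at class-creation time; instead those components receive the owning Dialect instance (see the dialect=self constructor changes near the end of this file). A minimal generic sketch of the two wiring styles, using made-up names for illustration:

    # Generic Python, not sqlglot API:
    class Settings:
        NULL_ORDERING = "nulls_are_small"

    class Parser:
        def __init__(self, dialect):
            self.dialect = dialect  # back-reference, read at use time

    parser = Parser(dialect=Settings())
    print(parser.dialect.NULL_ORDERING)  # settings live on the dialect, not the parser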
@@ -133,8 +127,6 @@ class _Dialect(type):
             TokenType.SEMI,
         }
 
-        klass.generator_class.can_identify = klass.can_identify
-
         return klass
 
 
@@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not the table alias comes after tablesample
     ALIAS_POST_TABLESAMPLE = False
 
-    # Determines whether or not unquoted identifiers are resolved as uppercase
-    # When set to None, it means that the dialect treats all identifiers as case-insensitive
-    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
+    # Specifies the strategy according to which identifiers should be normalized.
+    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 
     # Determines whether or not an unquoted identifier can start with a digit
     IDENTIFIERS_CAN_START_WITH_DIGIT = False
@@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect):
     # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
     NULL_ORDERING = "nulls_are_small"
 
+    # Whether the behavior of a / b depends on the types of a and b.
+    # False means a / b is always float division.
+    # True means a / b is integer division if both a and b are integers.
+    TYPED_DIVISION = False
+
+    # False means 1 / 0 throws an error.
+    # True means 1 / 0 returns null.
+    SAFE_DIVISION = False
+
+    # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
+    CONCAT_COALESCE = False
+
     DATE_FORMAT = "'%Y-%m-%d'"
     DATEINT_FORMAT = "'%Y%m%d'"
     TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
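The new division flags matter when transpiling arithmetic between engines whose "/" semantics differ. A small hedged example (exact output depends on the sqlglot version and the dialects involved):

    import sqlglot

    # T-SQL-style engines do integer division on integer operands, DuckDB does
    # float division, so the generator may need to add casts to preserve meaning.
    print(sqlglot.transpile("SELECT a / b FROM t", read="tsql", write="duckdb")[0])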
@@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect):
     # Such columns may be excluded from SELECT * queries, for example
     PSEUDOCOLUMNS: t.Set[str] = set()
 
-    # Autofilled
+    # --- Autofilled ---
+
     tokenizer_class = Tokenizer
     parser_class = Parser
     generator_class = Generator
@@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect):
 
     INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 
-    def __eq__(self, other: t.Any) -> bool:
-        return type(self) == other
+    # Delimiters for quotes, identifiers and the corresponding escape characters
+    QUOTE_START = "'"
+    QUOTE_END = "'"
+    IDENTIFIER_START = '"'
+    IDENTIFIER_END = '"'
 
-    def __hash__(self) -> int:
-        return hash(type(self))
+    # Delimiters for bit, hex and byte literals
+    BIT_START: t.Optional[str] = None
+    BIT_END: t.Optional[str] = None
+    HEX_START: t.Optional[str] = None
+    HEX_END: t.Optional[str] = None
+    BYTE_START: t.Optional[str] = None
+    BYTE_END: t.Optional[str] = None
 
     @classmethod
-    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
+    def get_or_raise(cls, dialect: DialectType) -> Dialect:
+        """
+        Look up a dialect in the global dialect registry and return it if it exists.
+
+        Args:
+            dialect: The target dialect. If this is a string, it can be optionally followed by
+                additional key-value pairs that are separated by commas and are used to specify
+                dialect settings, such as whether the dialect's identifiers are case-sensitive.
+
+        Example:
+            >>> dialect = dialect_class = get_or_raise("duckdb")
+            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
+
+        Returns:
+            The corresponding Dialect instance.
+        """
         if not dialect:
-            return cls
+            return cls()
         if isinstance(dialect, _Dialect):
-            return dialect
+            return dialect()
         if isinstance(dialect, Dialect):
-            return dialect.__class__
+            return dialect
+        if isinstance(dialect, str):
+            try:
+                dialect_name, *kv_pairs = dialect.split(",")
+                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
+            except ValueError:
+                raise ValueError(
+                    f"Invalid dialect format: '{dialect}'. "
+                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
+                )
 
-        result = cls.get(dialect)
-        if not result:
-            raise ValueError(f"Unknown dialect '{dialect}'")
+            result = cls.get(dialect_name.strip())
+            if not result:
+                raise ValueError(f"Unknown dialect '{dialect_name}'.")
 
-        return result
+            return result(**kwargs)
+
+        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 
     @classmethod
     def format_time(
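Behavioral note on the hunk above: get_or_raise used to return a Dialect subclass and now returns a configured instance, with any comma-separated settings suffix parsed into constructor kwargs. A usage sketch based on the new docstring:

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")  # an instance as of 20.x, not a class
    mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    print(type(duckdb).__name__, mysql.normalization_strategy)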
@@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect):
 
         return expression
 
-    @classmethod
-    def normalize_identifier(cls, expression: E) -> E:
+    def __init__(self, **kwargs) -> None:
+        normalization_strategy = kwargs.get("normalization_strategy")
+
+        if normalization_strategy is None:
+            self.normalization_strategy = self.NORMALIZATION_STRATEGY
+        else:
+            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
+
+    def __eq__(self, other: t.Any) -> bool:
+        # Does not currently take dialect state into account
+        return type(self) == other
+
+    def __hash__(self) -> int:
+        # Does not currently take dialect state into account
+        return hash(type(self))
+
+    def normalize_identifier(self, expression: E) -> E:
         """
-        Normalizes an unquoted identifier to either lower or upper case, thus essentially
-        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
-        they will be normalized to lowercase regardless of being quoted or not.
+        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
+
+        For example, an identifier like FoO would be resolved as foo in Postgres, because it
+        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
+        it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
+        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
+
+        There are also dialects like Spark, which are case-insensitive even when quotes are
+        present, and dialects like MySQL, whose resolution rules match those employed by the
+        underlying operating system, for example they may always be case-sensitive in Linux.
+
+        Finally, the normalization behavior of some engines can even be controlled through flags,
+        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
+
+        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
+        that it can analyze queries in the optimizer and successfully capture their semantics.
         """
-        if isinstance(expression, exp.Identifier) and (
-            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
+        if (
+            isinstance(expression, exp.Identifier)
+            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
+            and (
+                not expression.quoted
+                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
+            )
         ):
             expression.set(
                 "this",
                 expression.this.upper()
-                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
+                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                 else expression.this.lower(),
             )
 
         return expression
 
-    @classmethod
-    def case_sensitive(cls, text: str) -> bool:
+    def case_sensitive(self, text: str) -> bool:
         """Checks if text contains any case sensitive characters, based on the dialect's rules."""
-        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
+        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
             return False
 
-        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+        unsafe = (
+            str.islower
+            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+            else str.isupper
+        )
         return any(unsafe(char) for char in text)
 
-    @classmethod
-    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
+    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
         """Checks if text can be identified given an identify option.
 
         Args:
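A short sketch of the behavior the new docstring describes, assuming sqlglot 20.x (Postgres lowercases unquoted identifiers, Snowflake uppercases them; quoted identifiers are left alone):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted identifier
    pg = Dialect.get_or_raise("postgres").normalize_identifier(ident.copy())
    sf = Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy())
    print(pg.this, sf.this)  # foo FOO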
@@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect):
             return True
 
         if identify == "safe":
-            return not cls.case_sensitive(text)
+            return not self.case_sensitive(text)
 
         return False
 
-    @classmethod
-    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
+    def quote_identifier(self, expression: E, identify: bool = True) -> E:
         if isinstance(expression, exp.Identifier):
             name = expression.this
             expression.set(
                 "quoted",
-                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
+                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
             )
 
         return expression
@@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect):
     @property
     def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()
+            self._tokenizer = self.tokenizer_class(dialect=self)
         return self._tokenizer
 
     def parser(self, **opts) -> Parser:
-        return self.parser_class(**opts)
+        return self.parser_class(dialect=self, **opts)
 
     def generator(self, **opts) -> Generator:
-        return self.generator_class(**opts)
+        return self.generator_class(dialect=self, **opts)
 
 
 DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
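These three changes are the other half of the metaclass cleanup earlier in the diff: each component now carries the dialect instance instead of relying on copied class attributes. A usage sketch, assuming sqlglot 20.x:

    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("duckdb")
    tokens = d.tokenizer.tokenize("SELECT 1")  # tokenizer_class(dialect=d), cached
    ast = d.parser().parse(tokens)[0]          # parser_class(dialect=d)
    print(d.generator().generate(ast))         # generator_class(dialect=d)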
@@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
     return _ts_or_ds_to_date_sql
 
 
-def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
     return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 
 
@@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex
         return self.func(name, expression.this, expression.expression)
 
     return _arg_max_or_min_sql
+
+
+def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
+    this = expression.this.copy()
+
+    return_type = expression.return_type
+    if return_type.is_type(exp.DataType.Type.DATE):
+        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
+        # can truncate timestamp strings, because some dialects can't cast them to DATE
+        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
+
+    expression.this.replace(exp.cast(this, return_type))
+    return expression
+
+
+def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
+    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
+        if cast and isinstance(expression, exp.TsOrDsAdd):
+            expression = ts_or_ds_add_cast(expression)
+
+        return self.func(
+            name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+        )
+
+    return _delta_sql
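How the new helpers are meant to be plugged in (an illustrative sketch; this TRANSFORMS wiring is the usual sqlglot dialect pattern, not code from this diff):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_delta_sql

    # Hypothetical mapping inside a dialect's Generator subclass: with cast=True,
    # ts_or_ds_add_cast wraps the operand so DATE-typed results go via TIMESTAMP.
    TRANSFORMS = {
        exp.DateAdd: date_delta_sql("DATEADD"),
        exp.DateDiff: date_delta_sql("DATEDIFF"),
        exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
    }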