Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 2945bcc4f7
commit 4d9376ba93
132 changed files with 55125 additions and 51576 deletions
sqlglot/tokens.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import os
 import typing as t
 from enum import auto
 
-from sqlglot.errors import TokenError
+from sqlglot.errors import SqlglotError, TokenError
 from sqlglot.helper import AutoName
 from sqlglot.trie import TrieResult, in_trie, new_trie
 
@@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType
 
 
+try:
+    from sqlglotrs import (  # type: ignore
+        Tokenizer as RsTokenizer,
+        TokenizerDialectSettings as RsTokenizerDialectSettings,
+        TokenizerSettings as RsTokenizerSettings,
+        TokenTypeSettings as RsTokenTypeSettings,
+    )
+
+    USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
+except ImportError:
+    USE_RS_TOKENIZER = False
+
+
 class TokenType(AutoName):
     L_PAREN = auto()
     R_PAREN = auto()
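Note: `USE_RS_TOKENIZER` is resolved once, when `sqlglot.tokens` is first imported, so opting out of the Rust tokenizer has to happen before the import. A minimal sketch (the variable name and default come straight from the hunk above):

```python
import os

# Opt out of the Rust tokenizer; this must run before sqlglot is
# imported, because USE_RS_TOKENIZER is computed at module import time.
os.environ["SQLGLOTRS_TOKENIZER"] = "0"

import sqlglot  # noqa: E402  (import deliberately after the env tweak)

print(sqlglot.tokenize("SELECT 1"))  # tokenized by the pure-Python path
```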
@@ -83,6 +97,7 @@ class TokenType(AutoName):
     NATIONAL_STRING = auto()
     RAW_STRING = auto()
     HEREDOC_STRING = auto()
+    UNICODE_STRING = auto()
 
     # types
     BIT = auto()
@@ -347,6 +362,10 @@ class TokenType(AutoName):
     TIMESTAMP_SNAPSHOT = auto()
 
 
+_ALL_TOKEN_TYPES = list(TokenType)
+_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
+
+
 class Token:
     __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
 
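Note: these two tables exist because the Rust side exchanges plain integers rather than Python enum members (see the `token_type_index` fix-up in `tokenize_rs` below). A small sketch of the round trip, using the module-private names introduced above:

```python
from sqlglot.tokens import TokenType, _ALL_TOKEN_TYPES, _TOKEN_TYPE_TO_INDEX

# TokenType -> int on the way into Rust, int -> TokenType on the way out.
idx = _TOKEN_TYPE_TO_INDEX[TokenType.SELECT]
assert _ALL_TOKEN_TYPES[idx] is TokenType.SELECT
```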
@@ -432,6 +451,7 @@ class _Tokenizer(type):
             **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
             **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
             **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
+            **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
         }
 
         klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -455,6 +475,46 @@ class _Tokenizer(type):
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )
 
+        if USE_RS_TOKENIZER:
+            settings = RsTokenizerSettings(
+                white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
+                single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
+                keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
+                numeric_literals=klass.NUMERIC_LITERALS,
+                identifiers=klass._IDENTIFIERS,
+                identifier_escapes=klass._IDENTIFIER_ESCAPES,
+                string_escapes=klass._STRING_ESCAPES,
+                quotes=klass._QUOTES,
+                format_strings={
+                    k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
+                    for k, (v1, v2) in klass._FORMAT_STRINGS.items()
+                },
+                has_bit_strings=bool(klass.BIT_STRINGS),
+                has_hex_strings=bool(klass.HEX_STRINGS),
+                comments=klass._COMMENTS,
+                var_single_tokens=klass.VAR_SINGLE_TOKENS,
+                commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
+                command_prefix_tokens={
+                    _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
+                },
+            )
+            token_types = RsTokenTypeSettings(
+                bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+                break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
+                dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
+                heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+                hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
+                identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
+                number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
+                parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
+                semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
+                string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
+                var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+            )
+            klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
+        else:
+            klass._RS_TOKENIZER = None
+
         return klass
 
 
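Note: because the settings are assembled in the metaclass's `__new__`, every `Tokenizer` subclass (one per dialect) gets its own pre-built `RsTokenizer` at class-creation time rather than per instance. A quick way to observe this, assuming `sqlglotrs` is installed:

```python
from sqlglot.dialects.hive import Hive
from sqlglot.tokens import Tokenizer

# Each class creation re-runs the metaclass hook, so the two classes hold
# distinct Rust tokenizers (both are None when sqlglotrs is missing).
print(Tokenizer._RS_TOKENIZER is Hive.Tokenizer._RS_TOKENIZER)  # False
```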
@@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
     HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
     RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
     HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
     IDENTIFIER_ESCAPES = ['"']
     QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
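Note: `UNICODE_STRINGS` follows the same shape as `BIT_STRINGS`/`HEX_STRINGS`: a list of prefixes or (prefix, suffix) pairs. A hypothetical dialect tokenizer opting into Presto-style `U&'...'` literals might look like:

```python
from sqlglot.tokens import Tokenizer

class MyTokenizer(Tokenizer):  # hypothetical example class
    UNICODE_STRINGS = [("U&'", "'"), ("u&'", "'")]
```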
@@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
     _QUOTES: t.Dict[str, str] = {}
     _STRING_ESCAPES: t.Set[str] = set()
     _KEYWORD_TRIE: t.Dict = {}
+    _RS_TOKENIZER: t.Optional[t.Any] = None
 
     KEYWORDS: t.Dict[str, TokenType] = {
         **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
 
     # handle numeric literals like in hive (3L = BIGINT)
     NUMERIC_LITERALS: t.Dict[str, str] = {}
-    ENCODE: t.Optional[str] = None
 
     COMMENTS = ["--", ("/*", "*/")]
 
@@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
         "_end",
         "_peek",
         "_prev_token_line",
+        "_rs_dialect_settings",
     )
 
     def __init__(self, dialect: DialectType = None) -> None:
         from sqlglot.dialects import Dialect
 
         self.dialect = Dialect.get_or_raise(dialect)
+
+        if USE_RS_TOKENIZER:
+            self._rs_dialect_settings = RsTokenizerDialectSettings(
+                escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+                identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
+            )
+
         self.reset()
 
     def reset(self) -> None:
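Note: the class-level settings built by the metaclass are dialect-independent; per-dialect knobs (escape sequences, whether identifiers may start with a digit) travel separately through `_rs_dialect_settings`, built here from the resolved `Dialect`. For example:

```python
from sqlglot.tokens import Tokenizer

# The dialect string is resolved in __init__ via Dialect.get_or_raise,
# and its settings are snapshotted into _rs_dialect_settings whenever
# the Rust tokenizer is active.
tokenizer = Tokenizer(dialect="mysql")
```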
@@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):
 
     def tokenize(self, sql: str) -> t.List[Token]:
         """Returns a list of tokens corresponding to the SQL string `sql`."""
+        if USE_RS_TOKENIZER:
+            return self.tokenize_rs(sql)
+
         self.reset()
         self.sql = sql
         self.size = len(sql)
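Note: callers keep using `tokenize()`; the dispatch to the Rust path is transparent. For instance:

```python
from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("SELECT 1")
print([(tok.token_type, tok.text) for tok in tokens])
# e.g. [(TokenType.SELECT, 'SELECT'), (TokenType.NUMBER, '1')]
```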
@@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
             # Ensures we don't count an extra line if we get a \r\n line break sequence
             if self._char == "\r" and self._peek == "\n":
                 i = 2
+                self._start += 1
 
             self._col = 1
             self._line += 1
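Note: `i = 2` consumes the two-character `\r\n` sequence as one break; the added `self._start += 1` appears to compensate the token start offset for the extra character. A quick check of the line bookkeeping:

```python
from sqlglot.tokens import Tokenizer

# The token after a \r\n break should report line 2, not line 3.
tokens = Tokenizer().tokenize("SELECT 1\r\nFROM t")
print(tokens[-1].line)  # expected: 2
```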
@@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
                 raise TokenError(
                     f"Numeric string contains invalid characters from {self._line}:{self._start}"
                 )
-        else:
-            text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
 
         self._add(token_type, text)
         return True
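Note: with this `else` branch gone, `ENCODE` has no remaining users, which is why its declaration is dropped in the `@@ -804,7 +866,6 @@` hunk above. The surviving validation path can be exercised like this, assuming a dialect with hex literals such as MySQL:

```python
from sqlglot import tokenize
from sqlglot.errors import TokenError

# x'..' is a hex literal in MySQL; a non-hex digit should trip the
# TokenError raised in the hunk above.
try:
    tokenize("SELECT x'ZZ'", read="mysql")
except TokenError as e:
    print(e)
```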
@@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
             text += self.sql[current : self._current - 1]
 
         return text
+
+    def tokenize_rs(self, sql: str) -> t.List[Token]:
+        if not self._RS_TOKENIZER:
+            raise SqlglotError("Rust tokenizer is not available")
+
+        try:
+            tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
+            for token in tokens:
+                token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
+            return tokens
+        except Exception as e:
+            raise TokenError(str(e))
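Note: `tokenize_rs` re-tags each returned token with the real `TokenType` member (the Rust side only carries the integer index) and normalizes any Rust-side failure into the same `TokenError` the Python path raises, so callers see a uniform error type. A usage sketch:

```python
from sqlglot.tokens import Tokenizer

tok = Tokenizer()
if tok._RS_TOKENIZER is not None:
    tokens = tok.tokenize_rs("SELECT 1")  # Rust-backed scan
else:
    tokens = tok.tokenize("SELECT 1")     # pure-Python fallback
```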