
Merging upstream version 10.5.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann
Date: 2025-02-13 15:05:06 +01:00
parent 3b8c9606bf
commit 599f59b0f8
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)

39 changed files with 786 additions and 133 deletions

File: sqlglot/__init__.py

@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.5.2"
__version__ = "10.5.6"
pretty = False

File: sqlglot/dialects/__init__.py

@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
from sqlglot.dialects.teradata import Teradata
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL

File: sqlglot/dialects/bigquery.py

@ -165,6 +165,11 @@ class BigQuery(Dialect):
TokenType.TABLE,
}
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS, # type: ignore
TokenType.VALUES,
}
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore

File: sqlglot/dialects/clickhouse.py

@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.errors import ParseError
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@ -72,6 +73,30 @@ class ClickHouse(Dialect):
return this
def _parse_position(self) -> exp.Expression:
this = super()._parse_position()
# clickhouse position args are swapped
substr = this.this
this.args["this"] = this.args.get("substr")
this.args["substr"] = substr
return this
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression:
index = self._index
try:
# WITH <identifier> AS <subquery expression>
return super()._parse_cte()
except ParseError:
# WITH <expression> AS <identifier>
self._retreat(index)
statement = self._parse_statement()
if statement and isinstance(statement.this, exp.Alias):
self.raise_error("Expected CTE to have alias")
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@ -110,3 +135,9 @@ class ClickHouse(Dialect):
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
return f"({params})({args})"
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
return super().cte_sql(expression)
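A minimal usage sketch of the position-argument swap above, via sqlglot's public API; the expected output is my reading of this diff, not a captured run:

import sqlglot

# ClickHouse's position() takes (haystack, needle); the parser now swaps the
# arguments so exp.StrPosition stores them in the standard order.
print(sqlglot.transpile(
    "SELECT position('haystack', 'needle')",
    read="clickhouse",
    write="hive",
)[0])
# Expected: SELECT LOCATE('needle', 'haystack')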

File: sqlglot/dialects/dialect.py

@ -33,6 +33,7 @@ class Dialects(str, Enum):
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
TERADATA = "teradata"
class _Dialect(type):
@ -368,7 +369,7 @@ def locate_to_strposition(args):
)
def strposition_to_local_sql(self, expression):
def strposition_to_locate_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)

File: sqlglot/dialects/hive.py

@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
rename_func,
strposition_to_local_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
@ -297,7 +297,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,

File: sqlglot/dialects/mysql.py

@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
strposition_to_local_sql,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -122,6 +122,8 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"LONGTEXT": TokenType.LONGTEXT,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
@ -442,7 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
}
ROOT_PROPERTIES = {
@ -454,6 +456,10 @@ class MySQL(Dialect):
exp.LikeProperty,
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):
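A short sketch of the new MEDIUMTEXT/LONGTEXT handling, assuming sqlglot's usual transpile semantics: MySQL pops the types from TYPE_MAPPING and keeps them, while other dialects inherit the base mapping to TEXT added below in generator.py:

import sqlglot

# The MySQL round trip keeps the type.
print(sqlglot.transpile("CREATE TABLE t (c LONGTEXT)", read="mysql", write="mysql")[0])
# Expected: CREATE TABLE t (c LONGTEXT)

# Other dialects fall back to the base mapping LONGTEXT -> TEXT.
print(sqlglot.transpile("CREATE TABLE t (c LONGTEXT)", read="mysql", write="duckdb")[0])
# Expected: CREATE TABLE t (c TEXT)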

File: sqlglot/dialects/postgres.py

@ -223,19 +223,15 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
@ -299,6 +295,7 @@ class Postgres(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.Trim: trim_sql,

File: sqlglot/dialects/presto.py

@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_safe_divide_sql,
rename_func,
str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression):
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _concat_ws_sql(self, expression):
sep, *args = expression.expressions
sep = self.sql(sep)
if len(args) > 1:
return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
def _datatype_sql(self, expression):
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
@ -61,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.sql(expression, 'this')})"
return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
def _encode_sql(self, expression):
@ -119,6 +110,38 @@ def _ensure_utf8(charset):
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args):
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
weight=seq_get(args, 1),
quantile=seq_get(args, 2),
accuracy=seq_get(args, 3),
)
if len(args) == 3:
return exp.ApproxQuantile(
this=seq_get(args, 0),
quantile=seq_get(args, 1),
accuracy=seq_get(args, 2),
)
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args):
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
hours=seq_get(args, 1),
minutes=seq_get(args, 2),
)
if len(args) == 2:
return exp.UnixToTime(
this=seq_get(args, 0),
zone=seq_get(args, 1),
)
return exp.UnixToTime.from_arg_list(args)
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@ -150,19 +173,25 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
@ -194,7 +223,6 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.ConcatWs: _concat_ws_sql,
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@ -209,12 +237,13 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
exp.StrPosition: str_position_sql,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
@ -233,6 +262,7 @@ class Presto(Dialect):
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):
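The new FROM_UNIXTIME and STRPOS parsers are easiest to see on the parsed tree; a sketch with parse_one, where the argument names match the diff above:

import sqlglot
from sqlglot import exp

# FROM_UNIXTIME(unix, zone) now lands in UnixToTime with a zone argument.
e = sqlglot.parse_one("SELECT FROM_UNIXTIME(ts, 'UTC') FROM t", read="presto")
print(e.find(exp.UnixToTime).args.get("zone"))  # the 'UTC' literal

# STRPOS gained an optional third "instance" argument.
e = sqlglot.parse_one("SELECT STRPOS(x, 'a', 2) FROM t", read="presto")
print(e.find(exp.StrPosition).args.get("instance"))  # the literal 2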

File: sqlglot/dialects/redshift.py

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
@ -21,6 +23,19 @@ class Redshift(Postgres):
"NVL": exp.Coalesce.from_arg_list,
}
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func)
if (
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.expressions
and this.expressions[0] == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
return this
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
@ -52,6 +67,10 @@ class Redshift(Postgres):
exp.DistStyleProperty,
}
WITH_PROPERTIES = {
exp.LikeProperty,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
@ -60,3 +79,57 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
}
def values_sql(self, expression: exp.Values) -> str:
"""
Converts `VALUES...` expression into a series of unions.
Note: If there are many rows, the resulting chain of unions is deeply nested and evaluated
recursively, so you may need to increase `sys.setrecursionlimit`; generation can also be very slow.
"""
if not isinstance(expression.unnest().parent, exp.From):
return super().values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
selects = []
for i, row in enumerate(rows):
if i == 0:
row = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
selects.append(exp.Select(expressions=row))
subquery_expression = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
return self.subquery_sql(subquery_expression.subquery(expression.alias))
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
def renametable_sql(self, expression: exp.RenameTable) -> str:
"""Redshift only supports defining the table name itself (not the db) when renaming tables"""
expression = expression.copy()
target_table = expression.this
for arg in target_table.args:
if arg != "this":
target_table.set(arg, None)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default, but users generally mean a
VARCHAR of maximum length, which is `VARCHAR(MAX)` in Redshift. Therefore, a `TEXT` data type
without precision is converted to `VARCHAR(MAX)`, while one with precision is simply converted
to `VARCHAR`.
"""
if expression.this == exp.DataType.Type.TEXT:
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
return super().datatype_sql(expression)
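A sketch of the two Redshift generator changes; the expected outputs are inferred from the code above rather than captured from a test run:

import sqlglot

# TEXT without precision becomes VARCHAR(MAX); with precision it becomes VARCHAR(n).
print(sqlglot.transpile("CREATE TABLE t (c TEXT)", write="redshift")[0])
# Expected: CREATE TABLE t (c VARCHAR(MAX))

# VALUES in a FROM clause is rewritten as a chain of UNION ALL selects.
print(sqlglot.transpile("SELECT a FROM (VALUES (1), (2)) AS t(a)", write="redshift")[0])
# Expected: SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2) AS t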

File: sqlglot/dialects/snowflake.py

@ -210,6 +210,7 @@ class Snowflake(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DateAdd: rename_func("DATEADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@ -218,7 +219,7 @@ class Snowflake(Dialect):
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",

File: sqlglot/dialects/spark.py

@ -124,6 +124,7 @@ class Spark(Hive):
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)

File: sqlglot/dialects/sqlite.py

@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
def _fetch_sql(self, expression):
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
@ -30,6 +34,14 @@ def _group_concat_sql(self, expression):
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -71,6 +83,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: _date_add_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@ -78,8 +91,11 @@ class SQLite(Dialect):
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
exp.Fetch: _fetch_sql,
}
def transaction_sql(self, expression):

File: sqlglot/dialects/teradata.py (new file)

@ -0,0 +1,87 @@
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class Teradata(Dialect):
class Parser(parser.Parser):
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
"GRAPHIC_TO_LATIN",
"GRAPHIC_TO_UNICODE",
"GRAPHIC_TO_UNICODE_PadSpace",
"KANJI1_KanjiEBCDIC_TO_UNICODE",
"KANJI1_KanjiEUC_TO_UNICODE",
"KANJI1_KANJISJIS_TO_UNICODE",
"KANJI1_SBC_TO_UNICODE",
"KANJISJIS_TO_GRAPHIC",
"KANJISJIS_TO_LATIN",
"KANJISJIS_TO_UNICODE",
"LATIN_TO_GRAPHIC",
"LATIN_TO_KANJISJIS",
"LATIN_TO_UNICODE",
"LOCALE_TO_UNICODE",
"UNICODE_TO_GRAPHIC",
"UNICODE_TO_GRAPHIC_PadGraphic",
"UNICODE_TO_GRAPHIC_VarGraphic",
"UNICODE_TO_KANJI1_KanjiEBCDIC",
"UNICODE_TO_KANJI1_KanjiEUC",
"UNICODE_TO_KANJI1_KANJISJIS",
"UNICODE_TO_KANJI1_SBC",
"UNICODE_TO_KANJISJIS",
"UNICODE_TO_LATIN",
"UNICODE_TO_LOCALE",
"UNICODE_TO_UNICODE_FoldSpace",
"UNICODE_TO_UNICODE_Fullwidth",
"UNICODE_TO_UNICODE_Halfwidth",
"UNICODE_TO_UNICODE_NFC",
"UNICODE_TO_UNICODE_NFD",
"UNICODE_TO_UNICODE_NFKC",
"UNICODE_TO_UNICODE_NFKD",
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
def _parse_translate(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
if not self._match(TokenType.USING):
self.raise_error("Expected USING in TRANSLATE")
if self._match_texts(self.CHARSET_TRANSLATORS):
charset_split = self._prev.text.split("_TO_")
to = self.expression(exp.CharacterSet, this=charset_split[1])
else:
self.raise_error("Expected a character set translator after USING in TRANSLATE")
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def _parse_update(self) -> exp.Expression:
return self.expression(
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"from": self._parse_from(),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
},
)
class Generator(generator.Generator):
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
from_sql = self.sql(expression, "from")
set_sql = self.expressions(expression, flat=True)
where_sql = self.sql(expression, "where")
sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
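A brief sketch of the new Teradata dialect in use; the outputs are what the parser and generator above should produce, not verified against a release:

import sqlglot

# TRANSLATE(x USING <translator>) parses into a cast to the target character set.
e = sqlglot.parse_one("SELECT TRANSLATE(col USING LATIN_TO_UNICODE) FROM t", read="teradata")

# FROM precedes SET in Teradata's UPDATE, both when parsing and generating.
sql = "UPDATE a FROM b SET x = b.x WHERE a.id = b.id"
print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
# Expected: UPDATE a FROM b SET x = b.x WHERE a.id = b.id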

File: sqlglot/dialects/tsql.py

@ -243,28 +243,34 @@ class TSQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
"SMALLDATETIME": TokenType.DATETIME,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"DECLARE": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
"ROWVERSION": TokenType.ROWVERSION,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML,
"SQL_VARIANT": TokenType.VARIANT,
"NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT,
"VARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
"REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
"TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@ -288,7 +294,7 @@ class TSQL(Dialect):
}
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
TABLE_PREFIX_TOKENS = {TokenType.HASH}
TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
def _parse_convert(self, strict):
to = self._parse_types()

File: sqlglot/expressions.py

@ -653,6 +653,7 @@ class Create(Expression):
"statistics": False,
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
}
@ -770,6 +771,10 @@ class AlterColumn(Expression):
}
class RenameTable(Expression):
pass
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@ -804,7 +809,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
arg_types = {"this": True, "start": False, "increment": False}
class NotNullColumnConstraint(ColumnConstraintKind):
@ -1266,7 +1271,7 @@ class Tuple(Expression):
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True):
def subquery(self, alias=None, copy=True) -> Subquery:
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@ -1460,6 +1465,7 @@ class Unnest(UDTF):
"expressions": True,
"ordinality": False,
"alias": False,
"offset": False,
}
@ -2126,6 +2132,7 @@ class DataType(Expression):
"this": True,
"expressions": False,
"nested": False,
"values": False,
}
class Type(AutoName):
@ -2134,6 +2141,8 @@ class DataType(Expression):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@ -2791,7 +2800,7 @@ class Day(Func):
class Decode(Func):
arg_types = {"this": True, "charset": True}
arg_types = {"this": True, "charset": True, "replace": False}
class DiToDate(Func):
@ -2815,7 +2824,7 @@ class Floor(Func):
class Greatest(Func):
arg_types = {"this": True, "expressions": True}
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -2861,7 +2870,7 @@ class JSONBExtractScalar(JSONExtract):
class Least(Func):
arg_types = {"this": True, "expressions": True}
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -2904,7 +2913,7 @@ class Lower(Func):
class Map(Func):
arg_types = {"keys": True, "values": True}
arg_types = {"keys": False, "values": False}
class VarMap(Func):
@ -2923,11 +2932,11 @@ class Matches(Func):
class Max(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class Min(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class Month(Func):
@ -2962,7 +2971,7 @@ class QuantileIf(AggFunc):
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
class ReadCSV(Func):
@ -3022,7 +3031,12 @@ class Substring(Func):
class StrPosition(Func):
arg_types = {"substr": True, "this": True, "position": False}
arg_types = {
"this": True,
"substr": True,
"position": False,
"instance": False,
}
class StrToDate(Func):
@ -3129,8 +3143,10 @@ class UnixToStr(Func):
arg_types = {"this": True, "format": False}
# https://prestodb.io/docs/current/functions/datetime.html
# presto has weird zone/hours/minutes
class UnixToTime(Func):
arg_types = {"this": True, "scale": False}
arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
@ -3684,6 +3700,16 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@t.overload
def to_table(sql_path: None, **kwargs) -> None:
...
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@ -3860,6 +3886,26 @@ def values(
)
def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
"""Build ALTER TABLE... RENAME... expression
Args:
old_name: The old name of the table
new_name: The new name of the table
Returns:
Alter table expression
"""
old_table = to_table(old_name)
new_table = to_table(new_name)
return AlterTable(
this=old_table,
actions=[
RenameTable(this=new_table),
],
)
def convert(value) -> Expression:
"""Convert a python value into an expression object.

File: sqlglot/generator.py

@ -82,6 +82,8 @@ class Generator:
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@ -105,6 +107,7 @@ class Generator:
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"time_mapping",
@ -211,6 +214,8 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
if self.pretty:
sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def unsupported(self, message: str) -> None:
@ -401,7 +406,17 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
increment = f"INCREMENT BY {increment}" if increment else ""
sequence_opts = ""
if start or increment:
sequence_opts = f"{start} {increment}"
sequence_opts = f" ({sequence_opts.strip()})"
return (
f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
)
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@ -475,10 +490,13 @@ class Generator:
materialized,
)
)
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@ -517,13 +535,19 @@ class Generator:
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
if interior:
nested = (
f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("nested")
else f"({interior})"
)
return f"{type_sql}{nested}"
if expression.args.get("nested"):
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = (
f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
)
else:
nested = f"({interior})"
return f"{type_sql}{nested}{values}"
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@ -622,10 +646,14 @@ class Generator:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
def properties(
self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
return (
f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
)
return ""
def with_properties(self, properties: exp.Properties) -> str:
@ -763,14 +791,15 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
return f"VALUES{self.seg('')}{args}"
alias = f" AS {alias}" if alias else alias
if self.WRAP_DERIVED_VALUES:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
values = (
f"({values})"
if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
else values
)
return f"{values} AS {alias}" if alias else values
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
@ -868,6 +897,8 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
return text
@ -1036,7 +1067,9 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
return f"UNNEST({args}){ordinality}{alias}"
offset = expression.args.get("offset")
offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
return f"UNNEST({args}){ordinality}{alias}{offset}"
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
@ -1132,15 +1165,14 @@ class Generator:
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
return f"LTRIM({target})"
return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
elif trim_type == "TRAILING":
return f"RTRIM({target})"
return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
else:
return f"TRIM({target})"
return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
@ -1317,6 +1349,10 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
def renametable_sql(self, expression: exp.RenameTable) -> str:
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
@ -1326,7 +1362,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], exp.AlterColumn):
elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
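The identity-constraint generator above pairs with the START WITH / INCREMENT BY parsing added in parser.py below; a round-trip sketch, with the outcome inferred rather than verified:

import sqlglot

sql = "CREATE TABLE t (id INT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 10))"
print(sqlglot.transpile(sql)[0])
# Expected: the sequence options now survive the round trip instead of being dropped.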

File: sqlglot/optimizer/optimizer.py

@ -52,7 +52,10 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (sequence): sequence of optimizer rules to use
rules (sequence): sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
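A usage sketch for the expanded docstring: passing a schema lets qualify_columns resolve every column, which most downstream rules depend on (the exact quoting is my expectation, not a captured output):

import sqlglot
from sqlglot.optimizer import optimize

optimized = optimize(
    sqlglot.parse_one("SELECT a FROM t"),
    schema={"t": {"a": "INT", "b": "INT"}},
)
print(optimized.sql())
# Expected: SELECT "t"."a" AS "a" FROM "t" AS "t"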

File: sqlglot/optimizer/qualify_columns.py

@ -1,7 +1,7 @@
import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -382,7 +382,7 @@ class _Resolver:
try:
return self.schema.column_names(source, only_visible)
except Exception as e:
raise OptimizeError(str(e)) from e
raise SchemaError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names

File: sqlglot/parser.py

@ -107,6 +107,8 @@ class Parser(metaclass=_Parser):
TokenType.VARCHAR,
TokenType.NVARCHAR,
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
@ -233,6 +235,7 @@ class Parser(metaclass=_Parser):
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VIEW,
TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES,
@ -252,6 +255,7 @@ class Parser(metaclass=_Parser):
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
TokenType.COMMAND,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
@ -552,7 +556,7 @@ class Parser(metaclass=_Parser):
TokenType.IF: lambda self: self._parse_if(),
}
FUNCTION_PARSERS = {
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(),
@ -937,6 +941,7 @@ class Parser(metaclass=_Parser):
statistics = None
no_primary_index = None
indexes = None
no_schema_binding = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
@ -975,6 +980,9 @@ class Parser(metaclass=_Parser):
break
else:
indexes.append(index)
elif create_token.token_type == TokenType.VIEW:
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
return self.expression(
exp.Create,
@ -993,6 +1001,7 @@ class Parser(metaclass=_Parser):
statistics=statistics,
no_primary_index=no_primary_index,
indexes=indexes,
no_schema_binding=no_schema_binding,
)
def _parse_property(self) -> t.Optional[exp.Expression]:
@ -1246,8 +1255,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self) -> exp.Expression:
expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
@ -1313,19 +1328,9 @@ class Parser(metaclass=_Parser):
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression(
exp.Values,
expressions=expressions,
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
else:
@ -1612,13 +1617,12 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
if self._match(TokenType.WITH):
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set(
"hints",
self._parse_wrapped_csv(
lambda: self._parse_function() or self._parse_var(any_token=True)
),
self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
)
self._match_r_paren()
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@ -1643,8 +1647,17 @@ class Parser(metaclass=_Parser):
alias.set("columns", [alias.this])
alias.set("this", None)
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
offset = self._parse_conjunction()
return self.expression(
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
exp.Unnest,
expressions=expressions,
ordinality=ordinality,
alias=alias,
offset=offset,
)
def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
@ -1999,7 +2012,7 @@ class Parser(metaclass=_Parser):
this = self._parse_column()
if type_token:
if this:
if this and not isinstance(this, exp.Star):
return self.expression(exp.Cast, this=this, to=type_token)
if not type_token.args.get("expressions"):
self._retreat(index)
@ -2050,6 +2063,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@ -2059,6 +2073,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)):
values = self._parse_csv(self._parse_conjunction)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
@ -2097,9 +2115,13 @@ class Parser(metaclass=_Parser):
this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions,
nested=nested,
values=values,
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
if self._curr and self._curr.token_type in self.TYPE_TOKENS:
return self._parse_types()
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
@ -2412,6 +2434,14 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ALWAYS)
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
kind.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
kind.set("increment", self._parse_bitwise())
self._match_r_paren()
else:
return this
@ -2619,8 +2649,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.IN):
args.append(self._parse_bitwise())
# Note: we're parsing in order needle, haystack, position
this = exp.StrPosition.from_arg_list(args)
this = exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
self.validate_expression(this, args)
return this
@ -2999,6 +3033,8 @@ class Parser(metaclass=_Parser):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
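Two of the parser additions above, sketched end to end; the read dialects are chosen to exercise each path and the outputs are inferred from the diff:

import sqlglot

# UNNEST ... WITH OFFSET [AS alias] is now captured on the Unnest node.
print(sqlglot.transpile("SELECT x, i FROM UNNEST([1, 2]) AS x WITH OFFSET AS i", read="bigquery")[0])

# ALTER TABLE ... RENAME TO parses into an AlterTable with a RenameTable action.
print(sqlglot.transpile("ALTER TABLE t1 RENAME TO t2")[0])
# Expected: ALTER TABLE t1 RENAME TO t2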

File: sqlglot/tokens.py

@ -82,6 +82,8 @@ class TokenType(AutoName):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@ -434,6 +436,8 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
_ESCAPES: t.Set[str] = set()
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
@ -461,6 +465,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
"ALL": TokenType.ALL,
"ALWAYS": TokenType.ALWAYS,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@ -472,6 +477,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
@ -521,9 +527,11 @@ class Tokenizer(metaclass=_Tokenizer):
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GENERATED": TokenType.GENERATED,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
@ -746,7 +754,7 @@ class Tokenizer(metaclass=_Tokenizer):
)
def __init__(self) -> None:
self._replace_backslash = "\\" in self._ESCAPES # type: ignore
self._replace_backslash = "\\" in self._ESCAPES
self.reset()
def reset(self) -> None:
@ -771,7 +779,10 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)
self._scan()
return self.tokens
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
while self.size and not self._end:
self._start = self._current
self._advance()
@ -792,7 +803,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_identifier(identifier_end)
else:
self._scan_keywords()
return self.tokens
if until and until():
break
def _chars(self, size: int) -> str:
if size == 1:
@ -832,11 +845,13 @@ class Tokenizer(metaclass=_Tokenizer):
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
start = self._current
tokens = len(self.tokens)
self._scan(lambda: self._peek == ";")
self.tokens = self.tokens[:tokens]
text = self.sql[start : self._current].strip()
if text:
self._add(TokenType.STRING, text)
def _scan_keywords(self) -> None:
size = 0
@ -947,7 +962,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.isidentifier(): # type: ignore
number_text = self._text
literal = []
while self._peek.isidentifier(): # type: ignore
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
@ -1063,8 +1079,12 @@ class Tokenizer(metaclass=_Tokenizer):
delim_size = len(delimiter)
while True:
if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
text += delimiter
if (
self._char in self._ESCAPES
and self._peek
and (self._peek == delimiter or self._peek in self._ESCAPES)
):
text += self._peek
self._advance(2)
else:
if self._chars(delim_size) == delimiter: