
Adding upstream version 10.5.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:04:17 +01:00
parent b97d49f611
commit 556602e7d9
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
39 changed files with 786 additions and 133 deletions

View file

@ -462,7 +462,7 @@ make check # Set SKIP_INTEGRATION=1 to skip integration tests
| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) |
| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76E-5 (0.080) |
| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |

View file

@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.5.2"
__version__ = "10.5.6"
pretty = False

View file

@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
from sqlglot.dialects.teradata import Teradata
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL

View file

@ -165,6 +165,11 @@ class BigQuery(Dialect):
TokenType.TABLE,
}
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS, # type: ignore
TokenType.VALUES,
}
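
Allowing VALUES among the identifier tokens is what lets BigQuery's UNPIVOT clause use it as a column name. A minimal sketch through sqlglot's public API, reusing the query from the BigQuery test further down in this commit:

import sqlglot

# "values" now parses as a plain identifier inside UNPIVOT.
print(sqlglot.transpile("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))", read="bigquery", write="bigquery")[0])
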
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore

View file

@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.errors import ParseError
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@ -72,6 +73,30 @@ class ClickHouse(Dialect):
return this
def _parse_position(self) -> exp.Expression:
this = super()._parse_position()
# clickhouse position args are swapped
substr = this.this
this.args["this"] = this.args.get("substr")
this.args["substr"] = substr
return this
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression:
index = self._index
try:
# WITH <identifier> AS <subquery expression>
return super()._parse_cte()
except ParseError:
# WITH <expression> AS <identifier>
self._retreat(index)
statement = self._parse_statement()
if statement and isinstance(statement.this, exp.Alias):
self.raise_error("Expected CTE to have alias")
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@ -110,3 +135,9 @@ class ClickHouse(Dialect):
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
return f"({params})({args})"
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
return super().cte_sql(expression)
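
Taken together, the ClickHouse changes keep position()'s native (haystack, needle[, start]) argument order and accept the scalar WITH <expression> AS <identifier> CTE form alongside the standard one. A minimal sketch, reusing queries the ClickHouse tests further down validate as identities:

import sqlglot

# Scalar CTE form: WITH <expression> AS <identifier>.
print(sqlglot.transpile("WITH SUM(bytes) AS foo SELECT foo FROM system.parts", read="clickhouse", write="clickhouse")[0])

# position() keeps ClickHouse's swapped (haystack, needle) argument order.
print(sqlglot.transpile("position(a, b)", read="clickhouse", write="clickhouse")[0])
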

View file

@ -33,6 +33,7 @@ class Dialects(str, Enum):
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
TERADATA = "teradata"
class _Dialect(type):
@ -368,7 +369,7 @@ def locate_to_strposition(args):
)
def strposition_to_local_sql(self, expression):
def strposition_to_locate_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)

View file

@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
rename_func,
strposition_to_local_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
@ -297,7 +297,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,

View file

@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
strposition_to_local_sql,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -122,6 +122,8 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"LONGTEXT": TokenType.LONGTEXT,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
@ -442,7 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_local_sql,
exp.StrPosition: strposition_to_locate_sql,
}
ROOT_PROPERTIES = {
@ -454,6 +456,10 @@ class MySQL(Dialect):
exp.LikeProperty,
}
TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
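
MEDIUMTEXT and LONGTEXT now have their own token and data types; popping them from MySQL's TYPE_MAPPING keeps them verbatim for MySQL, while other dialects fall back to TEXT via the base generator mapping. A minimal sketch mirroring the MySQL test further down:

import sqlglot

# Round-trips in MySQL, falls back to TEXT in Spark.
print(sqlglot.transpile("CAST(x AS MEDIUMTEXT)", read="mysql", write="mysql")[0])
print(sqlglot.transpile("CAST(x AS MEDIUMTEXT)", read="mysql", write="spark")[0])
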
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):

View file

@ -223,19 +223,15 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
@ -299,6 +295,7 @@ class Postgres(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.Trim: trim_sql,

View file

@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_safe_divide_sql,
rename_func,
str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression):
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _concat_ws_sql(self, expression):
sep, *args = expression.expressions
sep = self.sql(sep)
if len(args) > 1:
return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
def _datatype_sql(self, expression):
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
@ -61,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.sql(expression, 'this')})"
return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
def _encode_sql(self, expression):
@ -119,6 +110,38 @@ def _ensure_utf8(charset):
raise UnsupportedError(f"Unsupported charset {charset}")
def _approx_percentile(args):
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
weight=seq_get(args, 1),
quantile=seq_get(args, 2),
accuracy=seq_get(args, 3),
)
if len(args) == 3:
return exp.ApproxQuantile(
this=seq_get(args, 0),
quantile=seq_get(args, 1),
accuracy=seq_get(args, 2),
)
return exp.ApproxQuantile.from_arg_list(args)
def _from_unixtime(args):
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
hours=seq_get(args, 1),
minutes=seq_get(args, 2),
)
if len(args) == 2:
return exp.UnixToTime(
this=seq_get(args, 0),
zone=seq_get(args, 1),
)
return exp.UnixToTime.from_arg_list(args)
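
These helpers map Presto's variadic signatures onto richer expressions: FROM_UNIXTIME(unix, zone) and FROM_UNIXTIME(unix, hours, minutes) become UnixToTime with the extra arguments, and the three- and four-argument APPROX_PERCENTILE forms become ApproxQuantile with accuracy and weight. A minimal sketch, reusing calls the Presto tests further down validate as identities:

import sqlglot

for sql in ("FROM_UNIXTIME(a, b)", "FROM_UNIXTIME(a, b, c)", "APPROX_PERCENTILE(a, b, c, d)"):
    expr = sqlglot.parse_one(sql, read="presto")
    print(type(expr).__name__, "->", expr.sql(dialect="presto"))
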
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@ -150,19 +173,25 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
@ -194,7 +223,6 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.ConcatWs: _concat_ws_sql,
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@ -209,12 +237,13 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
exp.StrPosition: str_position_sql,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
@ -233,6 +262,7 @@ class Presto(Dialect):
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
@ -21,6 +23,19 @@ class Redshift(Postgres):
"NVL": exp.Coalesce.from_arg_list,
}
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func)
if (
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.expressions
and this.expressions[0] == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
return this
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
@ -52,6 +67,10 @@ class Redshift(Postgres):
exp.DistStyleProperty,
}
WITH_PROPERTIES = {
exp.LikeProperty,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
@ -60,3 +79,57 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
}
def values_sql(self, expression: exp.Values) -> str:
"""
Converts a `VALUES...` expression into a series of unions.
Note: with many rows this produces a deeply nested chain of unions that is evaluated recursively,
so you may need to increase `sys.setrecursionlimit`, and generation can be slow.
"""
if not isinstance(expression.unnest().parent, exp.From):
return super().values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
selects = []
for i, row in enumerate(rows):
if i == 0:
row = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
selects.append(exp.Select(expressions=row))
subquery_expression = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(subquery_expression, select, distinct=False)
return self.subquery_sql(subquery_expression.subquery(expression.alias))
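
Since Redshift cannot consume a VALUES list directly as a FROM source, the rows are rewritten into SELECTs joined with UNION ALL, with the first row carrying the column aliases. A minimal sketch, reusing a case from the Redshift tests further down:

import sqlglot

sql = "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)"
print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])
# Per the tests: SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t
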
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
def renametable_sql(self, expression: exp.RenameTable) -> str:
"""Redshift only supports defining the table name itself (not the db) when renaming tables"""
expression = expression.copy()
target_table = expression.this
for arg in target_table.args:
if arg != "this":
target_table.set(arg, None)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift maps the `TEXT` data type to `VARCHAR(255)` by default, but what is usually meant is a
VARCHAR of maximum length, which Redshift spells `VARCHAR(MAX)`. Therefore a `TEXT` type without
precision is converted to `VARCHAR(MAX)`, while a `TEXT` with precision is simply converted to
`VARCHAR` with that precision.
"""
if expression.this == exp.DataType.Type.TEXT:
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
return super().datatype_sql(expression)
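
A minimal sketch of the TEXT handling (the input table is illustrative; the expected shape follows the dialect tests further down):

import sqlglot

# TEXT without precision becomes VARCHAR(MAX); TEXT(1024) keeps its length.
print(sqlglot.transpile("CREATE TABLE t (c1 TEXT, c2 TEXT(1024))", write="redshift")[0])
# Expected along the lines of: CREATE TABLE t (c1 VARCHAR(MAX), c2 VARCHAR(1024))
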

View file

@ -210,6 +210,7 @@ class Snowflake(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.DateAdd: rename_func("DATEADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@ -218,7 +219,7 @@ class Snowflake(Dialect):
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
exp.StrPosition: rename_func("POSITION"),
exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",

View file

@ -124,6 +124,7 @@ class Spark(Hive):
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)

View file

@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
def _fetch_sql(self, expression):
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
@ -30,6 +34,14 @@ def _group_concat_sql(self, expression):
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@ -71,6 +83,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: _date_add_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@ -78,8 +91,11 @@ class SQLite(Dialect):
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
exp.Fetch: _fetch_sql,
}
def transaction_sql(self, expression):

View file

@ -0,0 +1,87 @@
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class Teradata(Dialect):
class Parser(parser.Parser):
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
"GRAPHIC_TO_LATIN",
"GRAPHIC_TO_UNICODE",
"GRAPHIC_TO_UNICODE_PadSpace",
"KANJI1_KanjiEBCDIC_TO_UNICODE",
"KANJI1_KanjiEUC_TO_UNICODE",
"KANJI1_KANJISJIS_TO_UNICODE",
"KANJI1_SBC_TO_UNICODE",
"KANJISJIS_TO_GRAPHIC",
"KANJISJIS_TO_LATIN",
"KANJISJIS_TO_UNICODE",
"LATIN_TO_GRAPHIC",
"LATIN_TO_KANJISJIS",
"LATIN_TO_UNICODE",
"LOCALE_TO_UNICODE",
"UNICODE_TO_GRAPHIC",
"UNICODE_TO_GRAPHIC_PadGraphic",
"UNICODE_TO_GRAPHIC_VarGraphic",
"UNICODE_TO_KANJI1_KanjiEBCDIC",
"UNICODE_TO_KANJI1_KanjiEUC",
"UNICODE_TO_KANJI1_KANJISJIS",
"UNICODE_TO_KANJI1_SBC",
"UNICODE_TO_KANJISJIS",
"UNICODE_TO_LATIN",
"UNICODE_TO_LOCALE",
"UNICODE_TO_UNICODE_FoldSpace",
"UNICODE_TO_UNICODE_Fullwidth",
"UNICODE_TO_UNICODE_Halfwidth",
"UNICODE_TO_UNICODE_NFC",
"UNICODE_TO_UNICODE_NFD",
"UNICODE_TO_UNICODE_NFKC",
"UNICODE_TO_UNICODE_NFKD",
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
def _parse_translate(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
if not self._match(TokenType.USING):
self.raise_error("Expected USING in TRANSLATE")
if self._match_texts(self.CHARSET_TRANSLATORS):
charset_split = self._prev.text.split("_TO_")
to = self.expression(exp.CharacterSet, this=charset_split[1])
else:
self.raise_error("Expected a character set translator after USING in TRANSLATE")
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def _parse_update(self) -> exp.Expression:
return self.expression(
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"from": self._parse_from(),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
},
)
class Generator(generator.Generator):
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
from_sql = self.sql(expression, "from")
set_sql = self.expressions(expression, flat=True)
where_sql = self.sql(expression, "where")
sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
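
A minimal sketch exercising the new Teradata dialect; the TRANSLATE call comes from the Teradata tests further down, and the UPDATE is a simplified variant of the test query:

import sqlglot

# TRANSLATE ... USING <charset translator> becomes a CAST to the target character set.
print(sqlglot.transpile("TRANSLATE(x USING LATIN_TO_UNICODE)", read="teradata", write="teradata")[0])
# Per the tests: CAST(x AS CHAR CHARACTER SET UNICODE)

# Teradata's UPDATE puts FROM before SET; update_sql regenerates the same order.
print(sqlglot.transpile("UPDATE A FROM schema.tableA AS A SET col2 = '' WHERE A.col1 = 1", read="teradata", write="teradata")[0])
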

View file

@ -243,28 +243,34 @@ class TSQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
"SMALLDATETIME": TokenType.DATETIME,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"DECLARE": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
"ROWVERSION": TokenType.ROWVERSION,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML,
"SQL_VARIANT": TokenType.VARIANT,
"NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT,
"VARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
"REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
"TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"CHARINDEX": exp.StrPosition.from_arg_list,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@ -288,7 +294,7 @@ class TSQL(Dialect):
}
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
TABLE_PREFIX_TOKENS = {TokenType.HASH}
TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
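
Adding TokenType.PARAMETER to TABLE_PREFIX_TOKENS lets @-prefixed table variables parse as table names. A minimal sketch, reusing a query from the T-SQL tests further down:

import sqlglot

print(sqlglot.transpile("SELECT Employee_ID, Department_ID FROM @MyTableVar", read="tsql", write="tsql")[0])
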
def _parse_convert(self, strict):
to = self._parse_types()

View file

@ -653,6 +653,7 @@ class Create(Expression):
"statistics": False,
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
}
@ -770,6 +771,10 @@ class AlterColumn(Expression):
}
class RenameTable(Expression):
pass
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@ -804,7 +809,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
arg_types = {"this": True, "start": False, "increment": False}
class NotNullColumnConstraint(ColumnConstraintKind):
@ -1266,7 +1271,7 @@ class Tuple(Expression):
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True):
def subquery(self, alias=None, copy=True) -> Subquery:
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@ -1460,6 +1465,7 @@ class Unnest(UDTF):
"expressions": True,
"ordinality": False,
"alias": False,
"offset": False,
}
@ -2126,6 +2132,7 @@ class DataType(Expression):
"this": True,
"expressions": False,
"nested": False,
"values": False,
}
class Type(AutoName):
@ -2134,6 +2141,8 @@ class DataType(Expression):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@ -2791,7 +2800,7 @@ class Day(Func):
class Decode(Func):
arg_types = {"this": True, "charset": True}
arg_types = {"this": True, "charset": True, "replace": False}
class DiToDate(Func):
@ -2815,7 +2824,7 @@ class Floor(Func):
class Greatest(Func):
arg_types = {"this": True, "expressions": True}
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -2861,7 +2870,7 @@ class JSONBExtractScalar(JSONExtract):
class Least(Func):
arg_types = {"this": True, "expressions": True}
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -2904,7 +2913,7 @@ class Lower(Func):
class Map(Func):
arg_types = {"keys": True, "values": True}
arg_types = {"keys": False, "values": False}
class VarMap(Func):
@ -2923,11 +2932,11 @@ class Matches(Func):
class Max(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class Min(AggFunc):
pass
arg_types = {"this": True, "expression": False}
class Month(Func):
@ -2962,7 +2971,7 @@ class QuantileIf(AggFunc):
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
class ReadCSV(Func):
@ -3022,7 +3031,12 @@ class Substring(Func):
class StrPosition(Func):
arg_types = {"substr": True, "this": True, "position": False}
arg_types = {
"this": True,
"substr": True,
"position": False,
"instance": False,
}
class StrToDate(Func):
@ -3129,8 +3143,10 @@ class UnixToStr(Func):
arg_types = {"this": True, "format": False}
# https://prestodb.io/docs/current/functions/datetime.html
# Presto's FROM_UNIXTIME also accepts (unix, zone) and (unix, hours, minutes) signatures
class UnixToTime(Func):
arg_types = {"this": True, "scale": False}
arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
@ -3684,6 +3700,16 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@t.overload
def to_table(sql_path: None, **kwargs) -> None:
...
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@ -3860,6 +3886,26 @@ def values(
)
def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
"""Build ALTER TABLE... RENAME... expression
Args:
old_name: The old name of the table
new_name: The new name of the table
Returns:
Alter table expression
"""
old_table = to_table(old_name)
new_table = to_table(new_name)
return AlterTable(
this=old_table,
actions=[
RenameTable(this=new_table),
],
)
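
A minimal sketch of the new builder, matching the expression test added further down:

from sqlglot import exp

print(exp.rename_table("t1", "t2").sql())
# Per the tests: ALTER TABLE t1 RENAME TO t2
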
def convert(value) -> Expression:
"""Convert a python value into an expression object.

View file

@ -82,6 +82,8 @@ class Generator:
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@ -105,6 +107,7 @@ class Generator:
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"time_mapping",
@ -211,6 +214,8 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
if self.pretty:
sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def unsupported(self, message: str) -> None:
@ -401,7 +406,17 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
increment = f"INCREMENT BY {increment}" if increment else ""
sequence_opts = ""
if start or increment:
sequence_opts = f"{start} {increment}"
sequence_opts = f" ({sequence_opts.strip()})"
return (
f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
)
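
Combined with the parser change further down, START WITH and INCREMENT BY sequence options on identity columns now survive a round trip. A minimal sketch, using a statement from the identity fixtures added in this commit:

import sqlglot

sql = "CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))"
print(sqlglot.transpile(sql)[0])  # round-trips unchanged per the fixtures
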
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@ -475,10 +490,13 @@ class Generator:
materialized,
)
)
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
post_expression_modifiers = "".join((data, statistics, no_primary_index))
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@ -517,13 +535,19 @@ class Generator:
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
if interior:
nested = (
f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("nested")
else f"({interior})"
)
return f"{type_sql}{nested}"
if expression.args.get("nested"):
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = (
f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
)
else:
nested = f"({interior})"
return f"{type_sql}{nested}{values}"
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@ -622,10 +646,14 @@ class Generator:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
def properties(
self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
return (
f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
)
return ""
def with_properties(self, properties: exp.Properties) -> str:
@ -763,14 +791,15 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
return f"VALUES{self.seg('')}{args}"
alias = f" AS {alias}" if alias else alias
if self.WRAP_DERIVED_VALUES:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
alias = self.sql(expression, "alias")
values = f"VALUES{self.seg('')}{args}"
values = (
f"({values})"
if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
else values
)
return f"{values} AS {alias}" if alias else values
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
@ -868,6 +897,8 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
return text
@ -1036,7 +1067,9 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
return f"UNNEST({args}){ordinality}{alias}"
offset = expression.args.get("offset")
offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
return f"UNNEST({args}){ordinality}{alias}{offset}"
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
@ -1132,15 +1165,14 @@ class Generator:
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
return f"LTRIM({target})"
return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
elif trim_type == "TRAILING":
return f"RTRIM({target})"
return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
else:
return f"TRIM({target})"
return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
@ -1317,6 +1349,10 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
def renametable_sql(self, expression: exp.RenameTable) -> str:
this = self.sql(expression, "this")
return f"RENAME TO {this}"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
@ -1326,7 +1362,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], exp.AlterColumn):
elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")

View file

@ -52,7 +52,10 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlglot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (sequence): sequence of optimizer rules to use
rules (sequence): sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
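
A minimal sketch of calling the optimizer with a schema, which is what lets qualify_tables/qualify_columns resolve names; omitting the schema or dropping those rules is what surfaces the new SchemaError further down:

import sqlglot
from sqlglot.optimizer import optimize

expression = sqlglot.parse_one("SELECT a FROM x")
print(optimize(expression, schema={"x": {"a": "INT"}}).sql())
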

View file

@ -1,7 +1,7 @@
import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -382,7 +382,7 @@ class _Resolver:
try:
return self.schema.column_names(source, only_visible)
except Exception as e:
raise OptimizeError(str(e)) from e
raise SchemaError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names

View file

@ -107,6 +107,8 @@ class Parser(metaclass=_Parser):
TokenType.VARCHAR,
TokenType.NVARCHAR,
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
@ -233,6 +235,7 @@ class Parser(metaclass=_Parser):
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VIEW,
TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES,
@ -252,6 +255,7 @@ class Parser(metaclass=_Parser):
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
TokenType.COMMAND,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
@ -552,7 +556,7 @@ class Parser(metaclass=_Parser):
TokenType.IF: lambda self: self._parse_if(),
}
FUNCTION_PARSERS = {
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(),
@ -937,6 +941,7 @@ class Parser(metaclass=_Parser):
statistics = None
no_primary_index = None
indexes = None
no_schema_binding = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
@ -975,6 +980,9 @@ class Parser(metaclass=_Parser):
break
else:
indexes.append(index)
elif create_token.token_type == TokenType.VIEW:
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
return self.expression(
exp.Create,
@ -993,6 +1001,7 @@ class Parser(metaclass=_Parser):
statistics=statistics,
no_primary_index=no_primary_index,
indexes=indexes,
no_schema_binding=no_schema_binding,
)
def _parse_property(self) -> t.Optional[exp.Expression]:
@ -1246,8 +1255,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self) -> exp.Expression:
expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
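
With _parse_value handling both parenthesized tuples and bare scalars, Presto-style VALUES 1, 2 parses as one column and two rows, and the canonical output wraps each row in parentheses. A minimal sketch, matching the transpile test added further down:

import sqlglot

print(sqlglot.transpile("SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)", write="presto")[0])
# Per the tests: SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)
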
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
@ -1313,19 +1328,9 @@ class Parser(metaclass=_Parser):
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression(
exp.Values,
expressions=expressions,
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
else:
@ -1612,13 +1617,12 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
if self._match(TokenType.WITH):
if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set(
"hints",
self._parse_wrapped_csv(
lambda: self._parse_function() or self._parse_var(any_token=True)
),
self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
)
self._match_r_paren()
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@ -1643,8 +1647,17 @@ class Parser(metaclass=_Parser):
alias.set("columns", [alias.this])
alias.set("this", None)
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
offset = self._parse_conjunction()
return self.expression(
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
exp.Unnest,
expressions=expressions,
ordinality=ordinality,
alias=alias,
offset=offset,
)
def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
@ -1999,7 +2012,7 @@ class Parser(metaclass=_Parser):
this = self._parse_column()
if type_token:
if this:
if this and not isinstance(this, exp.Star):
return self.expression(exp.Cast, this=this, to=type_token)
if not type_token.args.get("expressions"):
self._retreat(index)
@ -2050,6 +2063,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@ -2059,6 +2073,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)):
values = self._parse_csv(self._parse_conjunction)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
@ -2097,9 +2115,13 @@ class Parser(metaclass=_Parser):
this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions,
nested=nested,
values=values,
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
if self._curr and self._curr.token_type in self.TYPE_TOKENS:
return self._parse_types()
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
@ -2412,6 +2434,14 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ALWAYS)
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
kind.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
kind.set("increment", self._parse_bitwise())
self._match_r_paren()
else:
return this
@ -2619,8 +2649,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.IN):
args.append(self._parse_bitwise())
# Note: we're parsing in order needle, haystack, position
this = exp.StrPosition.from_arg_list(args)
this = exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
self.validate_expression(this, args)
return this
@ -2999,6 +3033,8 @@ class Parser(metaclass=_Parser):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)

View file

@ -82,6 +82,8 @@ class TokenType(AutoName):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@ -434,6 +436,8 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
_ESCAPES: t.Set[str] = set()
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
@ -461,6 +465,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
"ALL": TokenType.ALL,
"ALWAYS": TokenType.ALWAYS,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@ -472,6 +477,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
@ -521,9 +527,11 @@ class Tokenizer(metaclass=_Tokenizer):
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GENERATED": TokenType.GENERATED,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
@ -746,7 +754,7 @@ class Tokenizer(metaclass=_Tokenizer):
)
def __init__(self) -> None:
self._replace_backslash = "\\" in self._ESCAPES # type: ignore
self._replace_backslash = "\\" in self._ESCAPES
self.reset()
def reset(self) -> None:
@ -771,7 +779,10 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)
self._scan()
return self.tokens
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
while self.size and not self._end:
self._start = self._current
self._advance()
@ -792,7 +803,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_identifier(identifier_end)
else:
self._scan_keywords()
return self.tokens
if until and until():
break
def _chars(self, size: int) -> str:
if size == 1:
@ -832,11 +845,13 @@ class Tokenizer(metaclass=_Tokenizer):
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
start = self._current
tokens = len(self.tokens)
self._scan(lambda: self._peek == ";")
self.tokens = self.tokens[:tokens]
text = self.sql[start : self._current].strip()
if text:
self._add(TokenType.STRING, text)
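
Commands now re-enter the scanner through the new until hook instead of slurping raw characters, so text after a command is tokenized normally and a quoted ';' no longer terminates it early. A minimal sketch, using a statement from the identity fixtures added in this commit:

import sqlglot

print(sqlglot.transpile("SET x = ';'")[0])  # round-trips unchanged per the fixtures
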
def _scan_keywords(self) -> None:
size = 0
@ -947,7 +962,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.isidentifier(): # type: ignore
number_text = self._text
literal = []
while self._peek.isidentifier(): # type: ignore
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
@ -1063,8 +1079,12 @@ class Tokenizer(metaclass=_Tokenizer):
delim_size = len(delimiter)
while True:
if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
text += delimiter
if (
self._char in self._ESCAPES
and self._peek
and (self._peek == delimiter or self._peek in self._ESCAPES)
):
text += self._peek
self._advance(2)
else:
if self._chars(delim_size) == delimiter:

View file

@ -6,6 +6,8 @@ class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_all(
"REGEXP_CONTAINS('foo', '.*')",
read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
@ -41,6 +43,15 @@ class TestBigQuery(Validator):
"spark": r"'/\\*.*\\*/'",
},
)
self.validate_all(
r"'\\'",
write={
"bigquery": r"'\\'",
"duckdb": r"'\'",
"presto": r"'\'",
"hive": r"'\\'",
},
)
self.validate_all(
R'R"""/\*.*\*/"""',
write={

View file

@ -17,6 +17,7 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
self.validate_identity("position(a, b)")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@ -47,3 +48,9 @@ class TestClickhouse(Validator):
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
},
)
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")

View file

@ -14,7 +14,7 @@ class Validator(unittest.TestCase):
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression
def validate_all(self, sql, read=None, write=None, pretty=False):
def validate_all(self, sql, read=None, write=None, pretty=False, identify=False):
"""
Validate that:
1. Everything in `read` transpiles to `sql`
@ -32,7 +32,10 @@ class Validator(unittest.TestCase):
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(
self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
self.dialect,
unsupported_level=ErrorLevel.IGNORE,
pretty=pretty,
identify=identify,
),
sql,
)
@ -48,6 +51,7 @@ class Validator(unittest.TestCase):
write_dialect,
unsupported_level=ErrorLevel.IGNORE,
pretty=pretty,
identify=identify,
),
write_sql,
)
@ -76,7 +80,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
"redshift": "CAST(a AS TEXT)",
"redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
@ -155,7 +159,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
"redshift": "CAST(a AS TEXT)",
"redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
@ -344,6 +348,7 @@ class TestDialect(Validator):
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
"sqlite": "'2020-01-01'",
},
)
self.validate_all(
@ -373,7 +378,7 @@ class TestDialect(Validator):
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
"redshift": "CAST(x AS TEXT)",
"redshift": "CAST(x AS VARCHAR(MAX))",
},
)
self.validate_all(
@ -488,7 +493,9 @@ class TestDialect(Validator):
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"postgres": "x + INTERVAL '1' 'day'",
"presto": "DATE_ADD('day', 1, x)",
"snowflake": "DATEADD(x, 1, 'day')",
"spark": "DATE_ADD(x, 1)",
"sqlite": "DATE(x, '1 day')",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
"tsql": "DATEADD(day, 1, x)",
},
@ -594,6 +601,7 @@ class TestDialect(Validator):
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
"sqlite": "x",
},
)
self.validate_all(
@ -955,7 +963,7 @@ class TestDialect(Validator):
},
)
self.validate_all(
"STR_POSITION('a', x)",
"STR_POSITION(x, 'a')",
write={
"drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
@ -971,7 +979,7 @@ class TestDialect(Validator):
"POSITION('a', x, 3)",
write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(x, 'a', 3)",
"spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)",
@ -982,9 +990,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', 'a', 'b')",
write={
"duckdb": "CONCAT_WS('-', 'a', 'b')",
"presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')",
"presto": "CONCAT_WS('-', 'a', 'b')",
"hive": "CONCAT_WS('-', 'a', 'b')",
"spark": "CONCAT_WS('-', 'a', 'b')",
"trino": "CONCAT_WS('-', 'a', 'b')",
},
)
@ -992,9 +1001,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', x)",
write={
"duckdb": "CONCAT_WS('-', x)",
"presto": "ARRAY_JOIN(x, '-')",
"hive": "CONCAT_WS('-', x)",
"presto": "CONCAT_WS('-', x)",
"spark": "CONCAT_WS('-', x)",
"trino": "CONCAT_WS('-', x)",
},
)
self.validate_all(
@ -1118,6 +1128,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY",
write={
"sqlite": "SELECT x FROM y LIMIT 3 OFFSET 10",
"oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
},
)
@ -1197,7 +1208,7 @@ class TestDialect(Validator):
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
"redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))",
"redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 VARCHAR(MAX), c2 VARCHAR(1024))",
},
)

View file

@ -356,6 +356,30 @@ class TestHive(Validator):
"spark": "SELECT a_b AS 1_a FROM test_table",
},
)
self.validate_all(
"SELECT 1a_1a FROM test_a",
write={
"spark": "SELECT 1a_1a FROM test_a",
},
)
self.validate_all(
"SELECT 1a AS 1a_1a FROM test_a",
write={
"spark": "SELECT 1a AS 1a_1a FROM test_a",
},
)
self.validate_all(
"CREATE TABLE test_table (1a STRING)",
write={
"spark": "CREATE TABLE test_table (1a STRING)",
},
)
self.validate_all(
"CREATE TABLE test_table2 (1a_1a STRING)",
write={
"spark": "CREATE TABLE test_table2 (1a_1a STRING)",
},
)
self.validate_all(
"PERCENTILE(x, 0.5)",
write={
@ -420,7 +444,7 @@ class TestHive(Validator):
"LOCATE('a', x, 3)",
write={
"duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(x, 'a', 3)",
"hive": "LOCATE('a', x, 3)",
"spark": "LOCATE('a', x, 3)",
},

View file

@ -65,6 +65,17 @@ class TestMySQL(Validator):
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
self.validate_identity("SELECT SCHEMA()")
def test_types(self):
self.validate_all(
"CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
read={
"mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
},
write={
"spark": "CAST(x AS TEXT) + CAST(y AS TEXT)",
},
)
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")

View file

@ -46,14 +46,6 @@ class TestPostgres(Validator):
" CONSTRAINT valid_discount CHECK (price > discounted_price))"
},
)
self.validate_all(
"CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
)
self.validate_all(
"CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
)
with self.assertRaises(ParseError):
transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")

View file

@ -152,6 +152,10 @@ class TestPresto(Validator):
"spark": "FROM_UNIXTIME(x)",
},
)
self.validate_identity("FROM_UNIXTIME(a, b)")
self.validate_identity("FROM_UNIXTIME(a, b, c)")
self.validate_identity("TRIM(a, b)")
self.validate_identity("VAR_POP(a)")
self.validate_all(
"TO_UNIXTIME(x)",
write={
@ -302,6 +306,7 @@ class TestPresto(Validator):
)
def test_presto(self):
self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
self.validate_all(
'SELECT a."b" FROM "foo"',
write={
@ -443,8 +448,10 @@ class TestPresto(Validator):
"spark": UnsupportedError,
},
)
self.validate_identity("SELECT * FROM (VALUES (1))")
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
def test_encode_decode(self):
self.validate_all(
@ -459,6 +466,12 @@ class TestPresto(Validator):
"spark": "DECODE(x, 'utf-8')",
},
)
self.validate_all(
"FROM_UTF8(x, y)",
write={
"presto": "FROM_UTF8(x, y)",
},
)
self.validate_all(
"ENCODE(x, 'utf-8')",
write={

View file

@ -89,7 +89,9 @@ class TestRedshift(Validator):
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
self.validate_identity(
"CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL"
)
self.validate_identity(
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
@ -102,3 +104,81 @@ class TestRedshift(Validator):
self.validate_identity(
"CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
)
def test_values(self):
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t",
},
)
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
},
)
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t",
},
)
self.validate_all(
"INSERT INTO t(a) VALUES (1), (2), (3)",
write={
"redshift": "INSERT INTO t (a) VALUES (1), (2), (3)",
},
)
self.validate_all(
"INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
write={
"redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
},
)
self.validate_all(
"INSERT INTO t(a, b) VALUES (1, 2), (3, 4)",
write={
"redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)",
},
)
def test_create_table_like(self):
self.validate_all(
"CREATE TABLE t1 LIKE t2",
write={
"redshift": "CREATE TABLE t1 (LIKE t2)",
},
)
self.validate_all(
"CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
write={
"redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
},
)
def test_rename_table(self):
self.validate_all(
"ALTER TABLE db.t1 RENAME TO db.t2",
write={
"spark": "ALTER TABLE db.t1 RENAME TO db.t2",
"redshift": "ALTER TABLE db.t1 RENAME TO t2",
},
)
def test_varchar_max(self):
self.validate_all(
"CREATE TABLE TEST (cola VARCHAR(MAX))",
write={
"redshift": 'CREATE TABLE "TEST" ("cola" VARCHAR(MAX))',
},
identify=True,
)
def test_no_schema_binding(self):
self.validate_all(
"CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
write={
"redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
},
)

View file

@ -307,5 +307,12 @@ TBLPROPERTIES (
def test_iif(self):
self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
"SELECT IIF(cond, 'True', 'False')",
write={"spark": "SELECT IF(cond, 'True', 'False')"},
)
def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)

View file

@ -0,0 +1,23 @@
from tests.dialects.test_dialect import Validator
class TestTeradata(Validator):
dialect = "teradata"
def test_translate(self):
self.validate_all(
"TRANSLATE(x USING LATIN_TO_UNICODE)",
write={
"teradata": "CAST(x AS CHAR CHARACTER SET UNICODE)",
},
)
self.validate_identity("CAST(x AS CHAR CHARACTER SET UNICODE)")
def test_update(self):
self.validate_all(
"UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
write={
"teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
"mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
},
)

View file

@ -5,6 +5,13 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
self.validate_identity("PRINT @TestVariable")
self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)")
self.validate_identity(
"SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID"
)
self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo")

View file

@ -59,6 +59,8 @@ map.x
SELECT call.x
a.b.INT(1.234)
INT(x / 100)
time * 100
int * 100
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
@ -69,6 +71,11 @@ x IS TRUE
x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
MAP()
GREATEST(x)
LEAST(y)
MAX(a, b)
MIN(a, b)
time
zone
ARRAY<TEXT>
@ -133,6 +140,7 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
SET x = ';'
COMMIT
USE db
NOT 1
@ -170,6 +178,7 @@ SELECT COUNT(DISTINCT a, b)
SELECT COUNT(DISTINCT a, b + 1)
SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x
SELECT TRUNCATE(a, b)
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x
@ -622,7 +631,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event
@ -643,3 +652,21 @@ ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
SELECT div.a FROM test_table AS div
WITH view AS (SELECT 1 AS x) SELECT * FROM view
CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA
CREATE TABLE asd AS SELECT asd FROM asd WITH DATA
ARRAY<STRUCT<INT, DOUBLE, ARRAY<INT>>>
ARRAY<INT>[1, 2, 3]
ARRAY<INT>[]
STRUCT<x VARCHAR(10)>
STRUCT<x VARCHAR(10)>("bla")
STRUCT<VARCHAR(10)>("bla")
STRUCT<INT>(5)
STRUCT<DATE>("2011-05-05")
STRUCT<x INT, y TEXT>(1, t.str_col)
SELECT CAST(NULL AS ARRAY<INT>) IS NULL AS array_is_null
CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)
CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10))

View file

@ -322,3 +322,23 @@ SELECT
* /* multi
line
comment */;
WITH table_data AS (
SELECT 'bob' AS name, ARRAY['banana', 'apple', 'orange'] AS fruit_basket
)
SELECT
name,
fruit,
basket_index
FROM table_data
CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index;
WITH table_data AS (
SELECT
'bob' AS name,
ARRAY('banana', 'apple', 'orange') AS fruit_basket
)
SELECT
name,
fruit,
basket_index
FROM table_data
CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index;

View file

@ -624,6 +624,10 @@ FROM foo""",
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)
empty_string = exp.to_table("")
self.assertEqual(empty_string.name, "")
self.assertIsNone(table_only.args.get("db"))
self.assertIsNone(table_only.args.get("catalog"))
def test_to_column(self):
column_only = exp.to_column("column_name")
@ -715,3 +719,9 @@ FROM foo""",
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
def test_rename_table(self):
self.assertEqual(
exp.rename_table("t1", "t2").sql(),
"ALTER TABLE t1 RENAME TO t2",
)

View file

@ -6,7 +6,7 @@ from pandas.testing import assert_frame_equal
import sqlglot
from sqlglot import exp, optimizer, parse_one
from sqlglot.errors import OptimizeError
from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
@ -161,7 +161,7 @@ class TestOptimizer(unittest.TestCase):
def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
with self.assertRaises(OptimizeError):
with self.assertRaises((OptimizeError, SchemaError)):
optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_lower_identities(self):

View file

@ -325,3 +325,9 @@ class TestParser(unittest.TestCase):
"Expected table name",
logger,
)
def test_rename_table(self):
self.assertEqual(
parse_one("ALTER TABLE foo RENAME TO bar").sql(),
"ALTER TABLE foo RENAME TO bar",
)

View file

@ -272,6 +272,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
"WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
"WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
)
self.validate(
"SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)",
"SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)",
write="presto",
)
def test_alter(self):
self.validate(
@ -447,6 +452,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
self.assertEqual(generated, pretty)
self.assertEqual(parse_one(sql), parse_one(pretty))
def test_pretty_line_breaks(self):
self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'")
@mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger):
invalid = "x + 1. ("