
Merging upstream version 10.5.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:05:06 +01:00
parent 3b8c9606bf
commit 599f59b0f8
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
39 changed files with 786 additions and 133 deletions

View file

@@ -462,7 +462,7 @@ make check # Set SKIP_INTEGRATION=1 to skip integration tests

| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
-| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) |
+| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76E-5 (0.080) |
| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |

View file

@@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "10.5.2"
+__version__ = "10.5.6"

pretty = False

View file

@@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
+from sqlglot.dialects.teradata import Teradata
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL

View file

@@ -165,6 +165,11 @@ class BigQuery(Dialect):
            TokenType.TABLE,
        }

+        ID_VAR_TOKENS = {
+            *parser.Parser.ID_VAR_TOKENS,  # type: ignore
+            TokenType.VALUES,
+        }
+
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore

View file

@@ -4,6 +4,7 @@ import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
+from sqlglot.errors import ParseError
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType

@@ -72,6 +73,30 @@ class ClickHouse(Dialect):

            return this

+        def _parse_position(self) -> exp.Expression:
+            this = super()._parse_position()
+            # clickhouse position args are swapped
+            substr = this.this
+            this.args["this"] = this.args.get("substr")
+            this.args["substr"] = substr
+            return this
+
+        # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
+        def _parse_cte(self) -> exp.Expression:
+            index = self._index
+            try:
+                # WITH <identifier> AS <subquery expression>
+                return super()._parse_cte()
+            except ParseError:
+                # WITH <expression> AS <identifier>
+                self._retreat(index)
+                statement = self._parse_statement()
+
+                if statement and isinstance(statement.this, exp.Alias):
+                    self.raise_error("Expected CTE to have alias")
+
+                return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+
    class Generator(generator.Generator):
        STRUCT_DELIMITER = ("(", ")")

@@ -110,3 +135,9 @@ class ClickHouse(Dialect):
            params = self.format_args(self.expressions(expression, params_name))
            args = self.format_args(self.expressions(expression, args_name))
            return f"({params})({args})"
+
+        def cte_sql(self, expression: exp.CTE) -> str:
+            if isinstance(expression.this, exp.Alias):
+                return self.sql(expression, "this")
+
+            return super().cte_sql(expression)

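The new ClickHouse hooks cover the `WITH <expression> AS <identifier>` CTE form from the linked docs and swap position() arguments back into sqlglot's order. A minimal way to exercise them (illustrative sketch; the table and column names are made up):

import sqlglot

# ClickHouse-style scalar CTE: the expression comes first, the alias last
sqlglot.parse_one("WITH SUM(bytes) AS s SELECT s FROM system.parts", read="clickhouse")

# position(haystack, needle) is normalized to StrPosition(this=haystack, substr=needle)
sqlglot.parse_one("SELECT position(haystack, needle) FROM t", read="clickhouse")
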
View file

@@ -33,6 +33,7 @@ class Dialects(str, Enum):
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
+    TERADATA = "teradata"


class _Dialect(type):

@@ -368,7 +369,7 @@ def locate_to_strposition(args):
    )


-def strposition_to_local_sql(self, expression):
+def strposition_to_locate_sql(self, expression):
    args = self.format_args(
        expression.args.get("substr"), expression.this, expression.args.get("position")
    )

View file

@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
    no_safe_divide_sql,
    no_trycast_sql,
    rename_func,
-    strposition_to_local_sql,
+    strposition_to_locate_sql,
    struct_extract_sql,
    timestrtotime_sql,
    var_map_sql,
@@ -297,7 +297,7 @@ class Hive(Dialect):
            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
            exp.SetAgg: rename_func("COLLECT_SET"),
            exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
-            exp.StrPosition: strposition_to_local_sql,
+            exp.StrPosition: strposition_to_locate_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: _str_to_time,
            exp.StrToUnix: _str_to_unix,

View file

@@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
    no_paren_current_date_sql,
    no_tablesample_sql,
    no_trycast_sql,
-    strposition_to_local_sql,
+    strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -122,6 +122,8 @@ class MySQL(Dialect):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
+            "MEDIUMTEXT": TokenType.MEDIUMTEXT,
+            "LONGTEXT": TokenType.LONGTEXT,
            "START": TokenType.BEGIN,
            "SEPARATOR": TokenType.SEPARATOR,
            "_ARMSCII8": TokenType.INTRODUCER,
@@ -442,7 +444,7 @@ class MySQL(Dialect):
            exp.Trim: _trim_sql,
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
-            exp.StrPosition: strposition_to_local_sql,
+            exp.StrPosition: strposition_to_locate_sql,
        }

        ROOT_PROPERTIES = {
@@ -454,6 +456,10 @@ class MySQL(Dialect):
            exp.LikeProperty,
        }

+        TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
+        TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
+        TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
+
        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()

        def show_sql(self, expression):

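With MEDIUMTEXT and LONGTEXT tokenized by MySQL and mapped to TEXT in the base generator, a transpile round-trip should look roughly like this (sketch; exact output depends on the target dialect):

import sqlglot

sqlglot.transpile("CREATE TABLE t (c MEDIUMTEXT)", read="mysql", write="mysql")    # type preserved for MySQL
sqlglot.transpile("CREATE TABLE t (c LONGTEXT)", read="mysql", write="postgres")   # falls back to TEXT elsewhere
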
View file

@@ -223,19 +223,15 @@ class Postgres(Dialect):
            "~~*": TokenType.ILIKE,
            "~*": TokenType.IRLIKE,
            "~": TokenType.RLIKE,
-            "ALWAYS": TokenType.ALWAYS,
            "BEGIN": TokenType.COMMAND,
            "BEGIN TRANSACTION": TokenType.BEGIN,
            "BIGSERIAL": TokenType.BIGSERIAL,
-            "BY DEFAULT": TokenType.BY_DEFAULT,
            "CHARACTER VARYING": TokenType.VARCHAR,
            "COMMENT ON": TokenType.COMMAND,
            "DECLARE": TokenType.COMMAND,
            "DO": TokenType.COMMAND,
-            "GENERATED": TokenType.GENERATED,
            "GRANT": TokenType.COMMAND,
            "HSTORE": TokenType.HSTORE,
-            "IDENTITY": TokenType.IDENTITY,
            "JSONB": TokenType.JSONB,
            "REFRESH": TokenType.COMMAND,
            "REINDEX": TokenType.COMMAND,
@@ -299,6 +295,7 @@ class Postgres(Dialect):
            exp.StrPosition: str_position_sql,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.Substring: _substring_sql,
+            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TableSample: no_tablesample_sql,
            exp.Trim: trim_sql,

View file

@@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
    no_ilike_sql,
    no_safe_divide_sql,
    rename_func,
-    str_position_sql,
    struct_extract_sql,
    timestrtotime_sql,
)
@@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression):
    return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"


-def _concat_ws_sql(self, expression):
-    sep, *args = expression.expressions
-    sep = self.sql(sep)
-    if len(args) > 1:
-        return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
-    return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
-
-
def _datatype_sql(self, expression):
    sql = self.datatype_sql(expression)
    if expression.this == exp.DataType.Type.TIMESTAMPTZ:
@@ -61,7 +52,7 @@ def _initcap_sql(self, expression):

def _decode_sql(self, expression):
    _ensure_utf8(expression.args.get("charset"))
-    return f"FROM_UTF8({self.sql(expression, 'this')})"
+    return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"


def _encode_sql(self, expression):
@@ -119,6 +110,38 @@ def _ensure_utf8(charset):
        raise UnsupportedError(f"Unsupported charset {charset}")


+def _approx_percentile(args):
+    if len(args) == 4:
+        return exp.ApproxQuantile(
+            this=seq_get(args, 0),
+            weight=seq_get(args, 1),
+            quantile=seq_get(args, 2),
+            accuracy=seq_get(args, 3),
+        )
+    if len(args) == 3:
+        return exp.ApproxQuantile(
+            this=seq_get(args, 0),
+            quantile=seq_get(args, 1),
+            accuracy=seq_get(args, 2),
+        )
+    return exp.ApproxQuantile.from_arg_list(args)
+
+
+def _from_unixtime(args):
+    if len(args) == 3:
+        return exp.UnixToTime(
+            this=seq_get(args, 0),
+            hours=seq_get(args, 1),
+            minutes=seq_get(args, 2),
+        )
+    if len(args) == 2:
+        return exp.UnixToTime(
+            this=seq_get(args, 0),
+            zone=seq_get(args, 1),
+        )
+    return exp.UnixToTime.from_arg_list(args)
+
+
class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"
@@ -150,19 +173,25 @@ class Presto(Dialect):
            ),
            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
            "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
-            "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
-            "STRPOS": exp.StrPosition.from_arg_list,
+            "FROM_UNIXTIME": _from_unixtime,
+            "STRPOS": lambda args: exp.StrPosition(
+                this=seq_get(args, 0),
+                substr=seq_get(args, 1),
+                instance=seq_get(args, 2),
+            ),
            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+            "APPROX_PERCENTILE": _approx_percentile,
            "FROM_HEX": exp.Unhex.from_arg_list,
            "TO_HEX": exp.Hex.from_arg_list,
            "TO_UTF8": lambda args: exp.Encode(
                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
            ),
            "FROM_UTF8": lambda args: exp.Decode(
-                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+                this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
            ),
        }

+        FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+        FUNCTION_PARSERS.pop("TRIM")
+
    class Generator(generator.Generator):
@@ -194,7 +223,6 @@ class Presto(Dialect):
            exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-            exp.ConcatWs: _concat_ws_sql,
            exp.DataType: _datatype_sql,
            exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -209,12 +237,13 @@ class Presto(Dialect):
            exp.Initcap: _initcap_sql,
            exp.Lateral: _explode_to_unnest_sql,
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Quantile: _quantile_sql,
            exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
            exp.SafeDivide: no_safe_divide_sql,
            exp.Schema: _schema_sql,
            exp.SortArray: _no_sort_array,
-            exp.StrPosition: str_position_sql,
+            exp.StrPosition: rename_func("STRPOS"),
            exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
            exp.StrToTime: _str_to_time_sql,
            exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
@@ -233,6 +262,7 @@ class Presto(Dialect):
            exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
            exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+            exp.VariancePop: rename_func("VAR_POP"),
        }

        def transaction_sql(self, expression):

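The extra arguments of APPROX_PERCENTILE, STRPOS and FROM_UNIXTIME are now captured instead of being dropped. A quick smoke test (sketch; column names are placeholders):

import sqlglot

sqlglot.parse_one("SELECT APPROX_PERCENTILE(x, w, 0.5, 0.01) FROM t", read="presto")
sqlglot.parse_one("SELECT STRPOS(haystack, 'a', 2) FROM t", read="presto")
sqlglot.parse_one("SELECT FROM_UNIXTIME(ts, 'America/New_York') FROM t", read="presto")
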
View file

@@ -1,5 +1,7 @@
from __future__ import annotations

+import typing as t
+
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
@@ -21,6 +23,19 @@ class Redshift(Postgres):
            "NVL": exp.Coalesce.from_arg_list,
        }

+        def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
+            this = super()._parse_types(check_func=check_func)
+
+            if (
+                isinstance(this, exp.DataType)
+                and this.this == exp.DataType.Type.VARCHAR
+                and this.expressions
+                and this.expressions[0] == exp.column("MAX")
+            ):
+                this.set("expressions", [exp.Var(this="MAX")])
+
+            return this
+
    class Tokenizer(Postgres.Tokenizer):
        ESCAPES = ["\\"]

@@ -52,6 +67,10 @@ class Redshift(Postgres):
            exp.DistStyleProperty,
        }

+        WITH_PROPERTIES = {
+            exp.LikeProperty,
+        }
+
        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,  # type: ignore
            **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
@@ -60,3 +79,57 @@ class Redshift(Postgres):
            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
            exp.Matches: rename_func("DECODE"),
        }
+
+        def values_sql(self, expression: exp.Values) -> str:
+            """
+            Converts `VALUES...` expression into a series of unions.
+
+            Note: If you have a lot of unions then this will result in a large number of recursive statements to
+            evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
+            very slow.
+            """
+            if not isinstance(expression.unnest().parent, exp.From):
+                return super().values_sql(expression)
+            rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+            selects = []
+            for i, row in enumerate(rows):
+                if i == 0:
+                    row = [
+                        exp.alias_(value, column_name)
+                        for value, column_name in zip(row, expression.args["alias"].args["columns"])
+                    ]
+                selects.append(exp.Select(expressions=row))
+            subquery_expression = selects[0]
+            if len(selects) > 1:
+                for select in selects[1:]:
+                    subquery_expression = exp.union(subquery_expression, select, distinct=False)
+            return self.subquery_sql(subquery_expression.subquery(expression.alias))
+
+        def with_properties(self, properties: exp.Properties) -> str:
+            """Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
+            return self.properties(properties, prefix=" ", suffix="")
+
+        def renametable_sql(self, expression: exp.RenameTable) -> str:
+            """Redshift only supports defining the table name itself (not the db) when renaming tables"""
+            expression = expression.copy()
+            target_table = expression.this
+            for arg in target_table.args:
+                if arg != "this":
+                    target_table.set(arg, None)
+            this = self.sql(expression, "this")
+            return f"RENAME TO {this}"
+
+        def datatype_sql(self, expression: exp.DataType) -> str:
+            """
+            Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
+            VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type
+            without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
+            `TEXT` to `VARCHAR`.
+            """
+            if expression.this == exp.DataType.Type.TEXT:
+                expression = expression.copy()
+                expression.set("this", exp.DataType.Type.VARCHAR)
+                precision = expression.args.get("expressions")
+                if not precision:
+                    expression.append("expressions", exp.Var(this="MAX"))
+            return super().datatype_sql(expression)

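A rough sketch of what the new Redshift generator methods do (outputs are approximate):

import sqlglot

# VALUES in a FROM clause is rewritten as a UNION ALL chain of SELECTs
sqlglot.transpile("SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t(a, b)", write="redshift")

# TEXT without precision is widened to VARCHAR(MAX)
sqlglot.transpile("CREATE TABLE t (c TEXT)", write="redshift")
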
View file

@@ -210,6 +210,7 @@ class Snowflake(Dialect):
            **generator.Generator.TRANSFORMS,  # type: ignore
            exp.Array: inline_array_sql,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.DateAdd: rename_func("DATEADD"),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DataType: _datatype_sql,
            exp.If: rename_func("IFF"),
@@ -218,7 +219,7 @@ class Snowflake(Dialect):
            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.Matches: rename_func("DECODE"),
-            exp.StrPosition: rename_func("POSITION"),
+            exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",

View file

@@ -124,6 +124,7 @@ class Spark(Hive):
            exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
            exp.VariancePop: rename_func("VAR_POP"),
            exp.DateFromParts: rename_func("MAKE_DATE"),
+            exp.LogicalOr: rename_func("BOOL_OR"),
        }

        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)

View file

@@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType


+def _fetch_sql(self, expression):
+    return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
+
+
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
    this = expression.this
@@ -30,6 +34,14 @@ def _group_concat_sql(self, expression):
    return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"


+def _date_add_sql(self, expression):
+    modifier = expression.expression
+    modifier = expression.name if modifier.is_string else self.sql(modifier)
+    unit = expression.args.get("unit")
+    modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
+    return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
+
+
class SQLite(Dialect):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -71,6 +83,7 @@ class SQLite(Dialect):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
+            exp.DateAdd: _date_add_sql,
            exp.ILike: no_ilike_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -78,8 +91,11 @@ class SQLite(Dialect):
            exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
            exp.Levenshtein: rename_func("EDITDIST3"),
            exp.TableSample: no_tablesample_sql,
+            exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
+            exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
            exp.TryCast: no_trycast_sql,
            exp.GroupConcat: _group_concat_sql,
+            exp.Fetch: _fetch_sql,
        }

    def transaction_sql(self, expression):

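DATE_ADD is now lowered to SQLite's DATE(x, modifier) form; a sketch (the exact modifier string may vary by input):

import sqlglot

sqlglot.transpile("SELECT DATE_ADD(x, 1, 'day')", write="sqlite")  # roughly DATE(x, '1 day')
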
View file

@@ -0,0 +1,87 @@
from __future__ import annotations

from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType


class Teradata(Dialect):
    class Parser(parser.Parser):
        CHARSET_TRANSLATORS = {
            "GRAPHIC_TO_KANJISJIS",
            "GRAPHIC_TO_LATIN",
            "GRAPHIC_TO_UNICODE",
            "GRAPHIC_TO_UNICODE_PadSpace",
            "KANJI1_KanjiEBCDIC_TO_UNICODE",
            "KANJI1_KanjiEUC_TO_UNICODE",
            "KANJI1_KANJISJIS_TO_UNICODE",
            "KANJI1_SBC_TO_UNICODE",
            "KANJISJIS_TO_GRAPHIC",
            "KANJISJIS_TO_LATIN",
            "KANJISJIS_TO_UNICODE",
            "LATIN_TO_GRAPHIC",
            "LATIN_TO_KANJISJIS",
            "LATIN_TO_UNICODE",
            "LOCALE_TO_UNICODE",
            "UNICODE_TO_GRAPHIC",
            "UNICODE_TO_GRAPHIC_PadGraphic",
            "UNICODE_TO_GRAPHIC_VarGraphic",
            "UNICODE_TO_KANJI1_KanjiEBCDIC",
            "UNICODE_TO_KANJI1_KanjiEUC",
            "UNICODE_TO_KANJI1_KANJISJIS",
            "UNICODE_TO_KANJI1_SBC",
            "UNICODE_TO_KANJISJIS",
            "UNICODE_TO_LATIN",
            "UNICODE_TO_LOCALE",
            "UNICODE_TO_UNICODE_FoldSpace",
            "UNICODE_TO_UNICODE_Fullwidth",
            "UNICODE_TO_UNICODE_Halfwidth",
            "UNICODE_TO_UNICODE_NFC",
            "UNICODE_TO_UNICODE_NFD",
            "UNICODE_TO_UNICODE_NFKC",
            "UNICODE_TO_UNICODE_NFKD",
        }

        FUNCTION_PARSERS = {
            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
            "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
        }

        def _parse_translate(self, strict: bool) -> exp.Expression:
            this = self._parse_conjunction()

            if not self._match(TokenType.USING):
                self.raise_error("Expected USING in TRANSLATE")

            if self._match_texts(self.CHARSET_TRANSLATORS):
                charset_split = self._prev.text.split("_TO_")
                to = self.expression(exp.CharacterSet, this=charset_split[1])
            else:
                self.raise_error("Expected a character set translator after USING in TRANSLATE")

            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

        # FROM before SET in Teradata UPDATE syntax
        # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
        def _parse_update(self) -> exp.Expression:
            return self.expression(
                exp.Update,
                **{  # type: ignore
                    "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
                    "from": self._parse_from(),
                    "expressions": self._match(TokenType.SET)
                    and self._parse_csv(self._parse_equality),
                    "where": self._parse_where(),
                },
            )

    class Generator(generator.Generator):
        # FROM before SET in Teradata UPDATE syntax
        # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
        def update_sql(self, expression: exp.Update) -> str:
            this = self.sql(expression, "this")
            from_sql = self.sql(expression, "from")
            set_sql = self.expressions(expression, flat=True)
            where_sql = self.sql(expression, "where")
            sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
            return self.prepend_ctes(expression, sql)

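Two of the Teradata constructs the new dialect understands, as a sketch (the identifiers are invented for illustration):

import sqlglot

sqlglot.parse_one("SELECT TRANSLATE(col USING LATIN_TO_UNICODE) FROM t", read="teradata")
sqlglot.parse_one("UPDATE a FROM b SET x = b.x WHERE a.id = b.id", read="teradata")
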
View file

@@ -243,28 +243,34 @@ class TSQL(Dialect):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "BIT": TokenType.BOOLEAN,
-            "REAL": TokenType.FLOAT,
-            "NTEXT": TokenType.TEXT,
-            "SMALLDATETIME": TokenType.DATETIME,
            "DATETIME2": TokenType.DATETIME,
            "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
-            "TIME": TokenType.TIMESTAMP,
+            "DECLARE": TokenType.COMMAND,
            "IMAGE": TokenType.IMAGE,
            "MONEY": TokenType.MONEY,
-            "SMALLMONEY": TokenType.SMALLMONEY,
-            "ROWVERSION": TokenType.ROWVERSION,
-            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
-            "XML": TokenType.XML,
-            "SQL_VARIANT": TokenType.VARIANT,
+            "NTEXT": TokenType.TEXT,
            "NVARCHAR(MAX)": TokenType.TEXT,
-            "VARCHAR(MAX)": TokenType.TEXT,
+            "PRINT": TokenType.COMMAND,
+            "REAL": TokenType.FLOAT,
+            "ROWVERSION": TokenType.ROWVERSION,
+            "SMALLDATETIME": TokenType.DATETIME,
+            "SMALLMONEY": TokenType.SMALLMONEY,
+            "SQL_VARIANT": TokenType.VARIANT,
+            "TIME": TokenType.TIMESTAMP,
            "TOP": TokenType.TOP,
+            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+            "VARCHAR(MAX)": TokenType.TEXT,
+            "XML": TokenType.XML,
        }

    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,  # type: ignore
-            "CHARINDEX": exp.StrPosition.from_arg_list,
+            "CHARINDEX": lambda args: exp.StrPosition(
+                this=seq_get(args, 1),
+                substr=seq_get(args, 0),
+                position=seq_get(args, 2),
+            ),
            "ISNULL": exp.Coalesce.from_arg_list,
            "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
            "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -288,7 +294,7 @@ class TSQL(Dialect):
        }

        # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
-        TABLE_PREFIX_TOKENS = {TokenType.HASH}
+        TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}

        def _parse_convert(self, strict):
            to = self._parse_types()

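CHARINDEX(substring, string) now maps onto StrPosition with its arguments in the right order, so transpiling to a LOCATE-based dialect should come out as expected (sketch):

import sqlglot

sqlglot.transpile("SELECT CHARINDEX('x', col)", read="tsql", write="spark")  # roughly LOCATE('x', col)
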
View file

@@ -653,6 +653,7 @@ class Create(Expression):
        "statistics": False,
        "no_primary_index": False,
        "indexes": False,
+        "no_schema_binding": False,
    }

@@ -770,6 +771,10 @@ class AlterColumn(Expression):
    }


+class RenameTable(Expression):
+    pass
+
+
class ColumnConstraint(Expression):
    arg_types = {"this": False, "kind": True}

@@ -804,7 +809,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):

class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
    # this: True -> ALWAYS, this: False -> BY DEFAULT
-    arg_types = {"this": True, "expression": False}
+    arg_types = {"this": True, "start": False, "increment": False}


class NotNullColumnConstraint(ColumnConstraintKind):
@@ -1266,7 +1271,7 @@ class Tuple(Expression):

class Subqueryable(Unionable):
-    def subquery(self, alias=None, copy=True):
+    def subquery(self, alias=None, copy=True) -> Subquery:
        """
        Convert this expression to an aliased expression that can be used as a Subquery.

@@ -1460,6 +1465,7 @@ class Unnest(UDTF):
        "expressions": True,
        "ordinality": False,
        "alias": False,
+        "offset": False,
    }

@@ -2126,6 +2132,7 @@ class DataType(Expression):
        "this": True,
        "expressions": False,
        "nested": False,
+        "values": False,
    }

    class Type(AutoName):
@@ -2134,6 +2141,8 @@ class DataType(Expression):
        VARCHAR = auto()
        NVARCHAR = auto()
        TEXT = auto()
+        MEDIUMTEXT = auto()
+        LONGTEXT = auto()
        BINARY = auto()
        VARBINARY = auto()
        INT = auto()

@@ -2791,7 +2800,7 @@ class Day(Func):

class Decode(Func):
-    arg_types = {"this": True, "charset": True}
+    arg_types = {"this": True, "charset": True, "replace": False}


class DiToDate(Func):
@@ -2815,7 +2824,7 @@ class Floor(Func):

class Greatest(Func):
-    arg_types = {"this": True, "expressions": True}
+    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True

@@ -2861,7 +2870,7 @@ class JSONBExtractScalar(JSONExtract):

class Least(Func):
-    arg_types = {"this": True, "expressions": True}
+    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True

@@ -2904,7 +2913,7 @@ class Lower(Func):

class Map(Func):
-    arg_types = {"keys": True, "values": True}
+    arg_types = {"keys": False, "values": False}


class VarMap(Func):
@@ -2923,11 +2932,11 @@ class Matches(Func):

class Max(AggFunc):
-    pass
+    arg_types = {"this": True, "expression": False}


class Min(AggFunc):
-    pass
+    arg_types = {"this": True, "expression": False}


class Month(Func):
@@ -2962,7 +2971,7 @@ class QuantileIf(AggFunc):

class ApproxQuantile(Quantile):
-    arg_types = {"this": True, "quantile": True, "accuracy": False}
+    arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}


class ReadCSV(Func):
@@ -3022,7 +3031,12 @@ class Substring(Func):

class StrPosition(Func):
-    arg_types = {"substr": True, "this": True, "position": False}
+    arg_types = {
+        "this": True,
+        "substr": True,
+        "position": False,
+        "instance": False,
+    }


class StrToDate(Func):
@@ -3129,8 +3143,10 @@ class UnixToStr(Func):
    arg_types = {"this": True, "format": False}


+# https://prestodb.io/docs/current/functions/datetime.html
+# presto has weird zone/hours/minutes
class UnixToTime(Func):
-    arg_types = {"this": True, "scale": False}
+    arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}

    SECONDS = Literal.string("seconds")
    MILLIS = Literal.string("millis")

@@ -3684,6 +3700,16 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
    return identifier


+@t.overload
+def to_table(sql_path: str | Table, **kwargs) -> Table:
+    ...
+
+
+@t.overload
+def to_table(sql_path: None, **kwargs) -> None:
+    ...
+
+
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
    """
    Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.

@@ -3860,6 +3886,26 @@ def values(
    )


+def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
+    """Build ALTER TABLE... RENAME... expression
+
+    Args:
+        old_name: The old name of the table
+        new_name: The new name of the table
+
+    Returns:
+        Alter table expression
+    """
+    old_table = to_table(old_name)
+    new_table = to_table(new_name)
+    return AlterTable(
+        this=old_table,
+        actions=[
+            RenameTable(this=new_table),
+        ],
+    )
+
+
def convert(value) -> Expression:
    """Convert a python value into an expression object.

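The new exp.rename_table builder pairs with the RenameTable expression above; a minimal usage sketch:

from sqlglot import exp

exp.rename_table("foo", "bar").sql()  # e.g. 'ALTER TABLE foo RENAME TO bar'
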
View file

@@ -82,6 +82,8 @@ class Generator:
    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",
+        exp.DataType.Type.MEDIUMTEXT: "TEXT",
+        exp.DataType.Type.LONGTEXT: "TEXT",
    }

    TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -105,6 +107,7 @@ class Generator:
    }

    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
+    SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

    __slots__ = (
        "time_mapping",
@@ -211,6 +214,8 @@ class Generator:
        elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
            raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))

+        if self.pretty:
+            sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
        return sql

    def unsupported(self, message: str) -> None:
@@ -401,7 +406,17 @@ class Generator:
    def generatedasidentitycolumnconstraint_sql(
        self, expression: exp.GeneratedAsIdentityColumnConstraint
    ) -> str:
-        return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+        start = expression.args.get("start")
+        start = f"START WITH {start}" if start else ""
+        increment = expression.args.get("increment")
+        increment = f"INCREMENT BY {increment}" if increment else ""
+        sequence_opts = ""
+        if start or increment:
+            sequence_opts = f"{start} {increment}"
+            sequence_opts = f" ({sequence_opts.strip()})"
+        return (
+            f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
+        )

    def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
        return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -475,10 +490,13 @@ class Generator:
                materialized,
            )
        )
+        no_schema_binding = (
+            " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
+        )

        post_expression_modifiers = "".join((data, statistics, no_primary_index))

-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
        return self.prepend_ctes(expression, expression_sql)

    def describe_sql(self, expression: exp.Describe) -> str:
@@ -517,13 +535,19 @@ class Generator:
        type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
        nested = ""
        interior = self.expressions(expression, flat=True)
+        values = ""
        if interior:
-            nested = (
-                f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
-                if expression.args.get("nested")
-                else f"({interior})"
-            )
-        return f"{type_sql}{nested}"
+            if expression.args.get("nested"):
+                nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
+                if expression.args.get("values") is not None:
+                    delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
+                    values = (
+                        f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
+                    )
+            else:
+                nested = f"({interior})"
+        return f"{type_sql}{nested}{values}"

    def directory_sql(self, expression: exp.Directory) -> str:
        local = "LOCAL " if expression.args.get("local") else ""
@@ -622,10 +646,14 @@ class Generator:
            return self.sep() + self.expressions(properties, indent=False, sep=" ")
        return ""

-    def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
+    def properties(
+        self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
+    ) -> str:
        if properties.expressions:
            expressions = self.expressions(properties, sep=sep, indent=False)
-            return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
+            return (
+                f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
+            )
        return ""

    def with_properties(self, properties: exp.Properties) -> str:
@@ -763,14 +791,15 @@ class Generator:
        return self.prepend_ctes(expression, sql)

    def values_sql(self, expression: exp.Values) -> str:
-        alias = self.sql(expression, "alias")
        args = self.expressions(expression)
-        if not alias:
-            return f"VALUES{self.seg('')}{args}"
-        alias = f" AS {alias}" if alias else alias
-        if self.WRAP_DERIVED_VALUES:
-            return f"(VALUES{self.seg('')}{args}){alias}"
-        return f"VALUES{self.seg('')}{args}{alias}"
+        alias = self.sql(expression, "alias")
+        values = f"VALUES{self.seg('')}{args}"
+        values = (
+            f"({values})"
+            if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
+            else values
+        )
+        return f"{values} AS {alias}" if alias else values

    def var_sql(self, expression: exp.Var) -> str:
        return self.sql(expression, "this")
@@ -868,6 +897,8 @@ class Generator:
            if self._replace_backslash:
                text = text.replace("\\", "\\\\")
            text = text.replace(self.quote_end, self._escaped_quote_end)
+            if self.pretty:
+                text = text.replace("\n", self.SENTINEL_LINE_BREAK)
            text = f"{self.quote_start}{text}{self.quote_end}"
        return text

@@ -1036,7 +1067,9 @@ class Generator:
        alias = self.sql(expression, "alias")
        alias = f" AS {alias}" if alias else alias
        ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
-        return f"UNNEST({args}){ordinality}{alias}"
+        offset = expression.args.get("offset")
+        offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
+        return f"UNNEST({args}){ordinality}{alias}{offset}"

    def where_sql(self, expression: exp.Where) -> str:
        this = self.indent(self.sql(expression, "this"))
@@ -1132,15 +1165,14 @@ class Generator:
        return f"EXTRACT({this} FROM {expression_sql})"

    def trim_sql(self, expression: exp.Trim) -> str:
-        target = self.sql(expression, "this")
        trim_type = self.sql(expression, "position")

        if trim_type == "LEADING":
-            return f"LTRIM({target})"
+            return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
        elif trim_type == "TRAILING":
-            return f"RTRIM({target})"
+            return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
        else:
-            return f"TRIM({target})"
+            return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"

    def concat_sql(self, expression: exp.Concat) -> str:
        if len(expression.expressions) == 1:
@@ -1317,6 +1349,10 @@ class Generator:
        return f"ALTER COLUMN {this} DROP DEFAULT"

+    def renametable_sql(self, expression: exp.RenameTable) -> str:
+        this = self.sql(expression, "this")
+        return f"RENAME TO {this}"
+
    def altertable_sql(self, expression: exp.AlterTable) -> str:
        actions = expression.args["actions"]

@@ -1326,7 +1362,7 @@ class Generator:
            actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
        elif isinstance(actions[0], exp.Drop):
            actions = self.expressions(expression, "actions")
-        elif isinstance(actions[0], exp.AlterColumn):
+        elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
            actions = self.sql(actions[0])
        else:
            self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")

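Taken together with the parser changes below, these hooks let statements like the following round-trip (sketch; exact formatting may differ):

import sqlglot

sqlglot.transpile("CREATE VIEW v AS SELECT c FROM t WITH NO SCHEMA BINDING", read="redshift")
sqlglot.transpile("ALTER TABLE a RENAME TO b")
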
View file

@@ -52,7 +52,10 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs):
            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
        db (str): specify the default database, as might be set by a `USE DATABASE db` statement
        catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
-        rules (sequence): sequence of optimizer rules to use
+        rules (sequence): sequence of optimizer rules to use.
+            Many of the rules require tables and columns to be qualified.
+            Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
+            what you're doing!
        **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
    Returns:
        sqlglot.Expression: optimized expression

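A minimal call that keeps the qualification rules in place (sketch; the schema mapping is an invented example):

import sqlglot
from sqlglot.optimizer import optimize

optimize(sqlglot.parse_one("SELECT a FROM x"), schema={"x": {"a": "INT"}})
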
View file

@@ -1,7 +1,7 @@
import itertools

from sqlglot import alias, exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

@@ -382,7 +382,7 @@ class _Resolver:
            try:
                return self.schema.column_names(source, only_visible)
            except Exception as e:
-                raise OptimizeError(str(e)) from e
+                raise SchemaError(str(e)) from e

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

View file

@ -107,6 +107,8 @@ class Parser(metaclass=_Parser):
TokenType.VARCHAR, TokenType.VARCHAR,
TokenType.NVARCHAR, TokenType.NVARCHAR,
TokenType.TEXT, TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
TokenType.BINARY, TokenType.BINARY,
TokenType.VARBINARY, TokenType.VARBINARY,
TokenType.JSON, TokenType.JSON,
@ -233,6 +235,7 @@ class Parser(metaclass=_Parser):
TokenType.UNPIVOT, TokenType.UNPIVOT,
TokenType.PROPERTIES, TokenType.PROPERTIES,
TokenType.PROCEDURE, TokenType.PROCEDURE,
TokenType.VIEW,
TokenType.VOLATILE, TokenType.VOLATILE,
TokenType.WINDOW, TokenType.WINDOW,
*SUBQUERY_PREDICATES, *SUBQUERY_PREDICATES,
@ -252,6 +255,7 @@ class Parser(metaclass=_Parser):
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = { FUNC_TOKENS = {
TokenType.COMMAND,
TokenType.CURRENT_DATE, TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME, TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIMESTAMP,
@ -552,7 +556,7 @@ class Parser(metaclass=_Parser):
TokenType.IF: lambda self: self._parse_if(), TokenType.IF: lambda self: self._parse_if(),
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False), "TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(), "EXTRACT": lambda self: self._parse_extract(),
@ -937,6 +941,7 @@ class Parser(metaclass=_Parser):
statistics = None statistics = None
no_primary_index = None no_primary_index = None
indexes = None indexes = None
no_schema_binding = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function() this = self._parse_user_defined_function()
@ -975,6 +980,9 @@ class Parser(metaclass=_Parser):
break break
else: else:
indexes.append(index) indexes.append(index)
elif create_token.token_type == TokenType.VIEW:
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
return self.expression( return self.expression(
exp.Create, exp.Create,
@ -993,6 +1001,7 @@ class Parser(metaclass=_Parser):
statistics=statistics, statistics=statistics,
no_primary_index=no_primary_index, no_primary_index=no_primary_index,
indexes=indexes, indexes=indexes,
no_schema_binding=no_schema_binding,
) )
def _parse_property(self) -> t.Optional[exp.Expression]: def _parse_property(self) -> t.Optional[exp.Expression]:
@ -1246,9 +1255,15 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self) -> exp.Expression: def _parse_value(self) -> exp.Expression:
expressions = self._parse_wrapped_csv(self._parse_conjunction) if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions) return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
def _parse_select( def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
) -> t.Optional[exp.Expression]: ) -> t.Optional[exp.Expression]:
@ -1313,19 +1328,9 @@ class Parser(metaclass=_Parser):
# Union ALL should be a property of the top select node, not the subquery # Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this, parse_alias=parse_subquery_alias) return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES): elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
expressions = self._parse_csv(self._parse_value)
else:
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# Source: https://prestodb.io/docs/current/sql/values.html
expressions = self._parse_csv(
lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
)
this = self.expression( this = self.expression(
exp.Values, exp.Values,
expressions=expressions, expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(), alias=self._parse_table_alias(),
) )
else: else:
@ -1612,13 +1617,12 @@ class Parser(metaclass=_Parser):
if alias: if alias:
this.set("alias", alias) this.set("alias", alias)
if self._match(TokenType.WITH): if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set( this.set(
"hints", "hints",
self._parse_wrapped_csv( self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
lambda: self._parse_function() or self._parse_var(any_token=True)
),
) )
self._match_r_paren()
if not self.alias_post_tablesample: if not self.alias_post_tablesample:
table_sample = self._parse_table_sample() table_sample = self._parse_table_sample()
@ -1643,8 +1647,17 @@ class Parser(metaclass=_Parser):
alias.set("columns", [alias.this]) alias.set("columns", [alias.this])
alias.set("this", None) alias.set("this", None)
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
offset = self._parse_conjunction()
return self.expression( return self.expression(
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias exp.Unnest,
expressions=expressions,
ordinality=ordinality,
alias=alias,
offset=offset,
) )
def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
@ -1999,7 +2012,7 @@ class Parser(metaclass=_Parser):
this = self._parse_column() this = self._parse_column()
if type_token: if type_token:
if this: if this and not isinstance(this, exp.Star):
return self.expression(exp.Cast, this=this, to=type_token) return self.expression(exp.Cast, this=this, to=type_token)
if not type_token.args.get("expressions"): if not type_token.args.get("expressions"):
self._retreat(index) self._retreat(index)
@ -2050,6 +2063,7 @@ class Parser(metaclass=_Parser):
self._retreat(index) self._retreat(index)
return None return None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT): if nested and self._match(TokenType.LT):
if is_struct: if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs) expressions = self._parse_csv(self._parse_struct_kwargs)
@ -2059,6 +2073,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT): if not self._match(TokenType.GT):
self.raise_error("Expecting >") self.raise_error("Expecting >")
if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)):
values = self._parse_csv(self._parse_conjunction)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
value: t.Optional[exp.Expression] = None value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS: if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
@ -2097,9 +2115,13 @@ class Parser(metaclass=_Parser):
this=exp.DataType.Type[type_token.value.upper()], this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions, expressions=expressions,
nested=nested, nested=nested,
values=values,
) )
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
if self._curr and self._curr.token_type in self.TYPE_TOKENS:
return self._parse_types()
this = self._parse_id_var() this = self._parse_id_var()
self._match(TokenType.COLON) self._match(TokenType.COLON)
data_type = self._parse_types() data_type = self._parse_types()
@ -2412,6 +2434,14 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ALWAYS) self._match(TokenType.ALWAYS)
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
kind.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
kind.set("increment", self._parse_bitwise())
self._match_r_paren()
else: else:
return this return this
@ -2619,8 +2649,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.IN): if self._match(TokenType.IN):
args.append(self._parse_bitwise()) args.append(self._parse_bitwise())
# Note: we're parsing in order needle, haystack, position this = exp.StrPosition(
this = exp.StrPosition.from_arg_list(args) this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
self.validate_expression(this, args) self.validate_expression(this, args)
return this return this
@ -2999,6 +3033,8 @@ class Parser(metaclass=_Parser):
actions = self._parse_csv(self._parse_add_column) actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False): elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column) actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
elif self._match_text_seq("ALTER"): elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN) self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True) column = self._parse_field(any_token=True)
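The parser hunks above add, among other things, an ALTER TABLE ... RENAME TO branch and START WITH / INCREMENT BY options on GENERATED ... AS IDENTITY. A minimal sketch of exercising both through the public parse_one helper; the statements and expected round-trips are taken from the test changes and fixtures further below:
import sqlglot

# New RENAME TO branch: parses and round-trips unchanged (per the parser test below).
print(sqlglot.parse_one("ALTER TABLE foo RENAME TO bar").sql())

# GENERATED ... AS IDENTITY now accepts START WITH / INCREMENT BY (per the identity fixtures below).
ddl = (
    "CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY "
    "(START WITH 10 INCREMENT BY 1))"
)
print(sqlglot.parse_one(ddl).sql())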

View file

@ -82,6 +82,8 @@ class TokenType(AutoName):
VARCHAR = auto() VARCHAR = auto()
NVARCHAR = auto() NVARCHAR = auto()
TEXT = auto() TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
BINARY = auto() BINARY = auto()
VARBINARY = auto() VARBINARY = auto()
JSON = auto() JSON = auto()
@ -434,6 +436,8 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"] ESCAPES = ["'"]
_ESCAPES: t.Set[str] = set()
KEYWORDS = { KEYWORDS = {
**{ **{
f"{key}{postfix}": TokenType.BLOCK_START f"{key}{postfix}": TokenType.BLOCK_START
@ -461,6 +465,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>>": TokenType.DHASH_ARROW, "#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW, "<->": TokenType.LR_ARROW,
"ALL": TokenType.ALL, "ALL": TokenType.ALL,
"ALWAYS": TokenType.ALWAYS,
"AND": TokenType.AND, "AND": TokenType.AND,
"ANTI": TokenType.ANTI, "ANTI": TokenType.ANTI,
"ANY": TokenType.ANY, "ANY": TokenType.ANY,
@ -472,6 +477,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN, "BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH, "BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET, "BUCKET": TokenType.BUCKET,
"BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE, "CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE, "UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE, "CASE": TokenType.CASE,
@ -521,9 +527,11 @@ class Tokenizer(metaclass=_Tokenizer):
"FOREIGN KEY": TokenType.FOREIGN_KEY, "FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT, "FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM, "FROM": TokenType.FROM,
"GENERATED": TokenType.GENERATED,
"GROUP BY": TokenType.GROUP_BY, "GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS, "GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING, "HAVING": TokenType.HAVING,
"IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF, "IF": TokenType.IF,
"ILIKE": TokenType.ILIKE, "ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE, "IMMUTABLE": TokenType.IMMUTABLE,
@ -746,7 +754,7 @@ class Tokenizer(metaclass=_Tokenizer):
) )
def __init__(self) -> None: def __init__(self) -> None:
self._replace_backslash = "\\" in self._ESCAPES # type: ignore self._replace_backslash = "\\" in self._ESCAPES
self.reset() self.reset()
def reset(self) -> None: def reset(self) -> None:
@ -771,7 +779,10 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset() self.reset()
self.sql = sql self.sql = sql
self.size = len(sql) self.size = len(sql)
self._scan()
return self.tokens
def _scan(self, until: t.Optional[t.Callable] = None) -> None:
while self.size and not self._end: while self.size and not self._end:
self._start = self._current self._start = self._current
self._advance() self._advance()
@ -792,7 +803,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_identifier(identifier_end) self._scan_identifier(identifier_end)
else: else:
self._scan_keywords() self._scan_keywords()
return self.tokens
if until and until():
break
def _chars(self, size: int) -> str: def _chars(self, size: int) -> str:
if size == 1: if size == 1:
@ -832,11 +845,13 @@ class Tokenizer(metaclass=_Tokenizer):
if token_type in self.COMMANDS and ( if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
): ):
self._start = self._current start = self._current
while not self._end and self._peek != ";": tokens = len(self.tokens)
self._advance() self._scan(lambda: self._peek == ";")
if self._start < self._current: self.tokens = self.tokens[:tokens]
self._add(TokenType.STRING) text = self.sql[start : self._current].strip()
if text:
self._add(TokenType.STRING, text)
def _scan_keywords(self) -> None: def _scan_keywords(self) -> None:
size = 0 size = 0
@ -947,7 +962,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.isidentifier(): # type: ignore elif self._peek.isidentifier(): # type: ignore
number_text = self._text number_text = self._text
literal = [] literal = []
while self._peek.isidentifier(): # type: ignore
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
literal.append(self._peek.upper()) # type: ignore literal.append(self._peek.upper()) # type: ignore
self._advance() self._advance()
@ -1063,8 +1079,12 @@ class Tokenizer(metaclass=_Tokenizer):
delim_size = len(delimiter) delim_size = len(delimiter)
while True: while True:
if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore if (
text += delimiter self._char in self._ESCAPES
and self._peek
and (self._peek == delimiter or self._peek in self._ESCAPES)
):
text += self._peek
self._advance(2) self._advance(2)
else: else:
if self._chars(delim_size) == delimiter: if self._chars(delim_size) == delimiter:
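The tokenizer hunks above widen string-literal escape handling: an escape character may now escape another escape character, not just the closing delimiter. A minimal sketch, with the expected outputs taken from the BigQuery test further below:
import sqlglot

# A doubled backslash read as BigQuery collapses to a single escaped backslash in
# dialects without backslash escapes, and is kept doubled where they are used.
print(sqlglot.transpile(r"'\\'", read="bigquery", write="duckdb")[0])  # expected: '\'
print(sqlglot.transpile(r"'\\'", read="bigquery", write="hive")[0])    # expected: '\\'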

View file

@ -6,6 +6,8 @@ class TestBigQuery(Validator):
dialect = "bigquery" dialect = "bigquery"
def test_bigquery(self): def test_bigquery(self):
self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_all( self.validate_all(
"REGEXP_CONTAINS('foo', '.*')", "REGEXP_CONTAINS('foo', '.*')",
read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
@ -41,6 +43,15 @@ class TestBigQuery(Validator):
"spark": r"'/\\*.*\\*/'", "spark": r"'/\\*.*\\*/'",
}, },
) )
self.validate_all(
r"'\\'",
write={
"bigquery": r"'\\'",
"duckdb": r"'\'",
"presto": r"'\'",
"hive": r"'\\'",
},
)
self.validate_all( self.validate_all(
R'R"""/\*.*\*/"""', R'R"""/\*.*\*/"""',
write={ write={
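The typed-constructor changes in the parser (values attached to parameterized types, typed STRUCT fields) surface in the new BigQuery identity above. A small sketch using the top-level transpile helper:
import sqlglot

# Typed STRUCT/ARRAY constructors now parse and round-trip in BigQuery.
print(sqlglot.transpile("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])",
                        read="bigquery", write="bigquery")[0])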

View file

@ -17,6 +17,7 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
self.validate_identity("position(a, b)")
self.validate_all( self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@ -47,3 +48,9 @@ class TestClickhouse(Validator):
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
}, },
) )
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")

View file

@ -14,7 +14,7 @@ class Validator(unittest.TestCase):
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression return expression
def validate_all(self, sql, read=None, write=None, pretty=False): def validate_all(self, sql, read=None, write=None, pretty=False, identify=False):
""" """
Validate that: Validate that:
1. Everything in `read` transpiles to `sql` 1. Everything in `read` transpiles to `sql`
@ -32,7 +32,10 @@ class Validator(unittest.TestCase):
with self.subTest(f"{read_dialect} -> {sql}"): with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual( self.assertEqual(
parse_one(read_sql, read_dialect).sql( parse_one(read_sql, read_dialect).sql(
self.dialect,
unsupported_level=ErrorLevel.IGNORE,
pretty=pretty,
identify=identify,
), ),
sql, sql,
) )
@ -48,6 +51,7 @@ class Validator(unittest.TestCase):
write_dialect, write_dialect,
unsupported_level=ErrorLevel.IGNORE, unsupported_level=ErrorLevel.IGNORE,
pretty=pretty, pretty=pretty,
identify=identify,
), ),
write_sql, write_sql,
) )
@ -76,7 +80,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)", "oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)", "postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)", "presto": "CAST(a AS VARCHAR)",
"redshift": "CAST(a AS TEXT)", "redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)", "snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)", "spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)",
@ -155,7 +159,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)", "oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)", "postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)", "presto": "CAST(a AS VARCHAR)",
"redshift": "CAST(a AS TEXT)", "redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)", "snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)", "spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)",
@ -344,6 +348,7 @@ class TestDialect(Validator):
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "CAST('2020-01-01' AS TIMESTAMP)",
"sqlite": "'2020-01-01'",
}, },
) )
self.validate_all( self.validate_all(
@ -373,7 +378,7 @@ class TestDialect(Validator):
"duckdb": "CAST(x AS TEXT)", "duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)", "hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)", "presto": "CAST(x AS VARCHAR)",
"redshift": "CAST(x AS TEXT)", "redshift": "CAST(x AS VARCHAR(MAX))",
}, },
) )
self.validate_all( self.validate_all(
@ -488,7 +493,9 @@ class TestDialect(Validator):
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"postgres": "x + INTERVAL '1' 'day'", "postgres": "x + INTERVAL '1' 'day'",
"presto": "DATE_ADD('day', 1, x)", "presto": "DATE_ADD('day', 1, x)",
"snowflake": "DATEADD(x, 1, 'day')",
"spark": "DATE_ADD(x, 1)", "spark": "DATE_ADD(x, 1)",
"sqlite": "DATE(x, '1 day')",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
"tsql": "DATEADD(day, 1, x)", "tsql": "DATEADD(day, 1, x)",
}, },
@ -594,6 +601,7 @@ class TestDialect(Validator):
"hive": "TO_DATE(x)", "hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)", "spark": "TO_DATE(x)",
"sqlite": "x",
}, },
) )
self.validate_all( self.validate_all(
@ -955,7 +963,7 @@ class TestDialect(Validator):
}, },
) )
self.validate_all( self.validate_all(
"STR_POSITION('a', x)", "STR_POSITION(x, 'a')",
write={ write={
"drill": "STRPOS(x, 'a')", "drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')",
@ -971,7 +979,7 @@ class TestDialect(Validator):
"POSITION('a', x, 3)", "POSITION('a', x, 3)",
write={ write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "presto": "STRPOS(x, 'a', 3)",
"spark": "LOCATE('a', x, 3)", "spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)", "clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)", "snowflake": "POSITION('a', x, 3)",
@ -982,9 +990,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', 'a', 'b')", "CONCAT_WS('-', 'a', 'b')",
write={ write={
"duckdb": "CONCAT_WS('-', 'a', 'b')", "duckdb": "CONCAT_WS('-', 'a', 'b')",
"presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')", "presto": "CONCAT_WS('-', 'a', 'b')",
"hive": "CONCAT_WS('-', 'a', 'b')", "hive": "CONCAT_WS('-', 'a', 'b')",
"spark": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')",
"trino": "CONCAT_WS('-', 'a', 'b')",
}, },
) )
@ -992,9 +1001,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', x)", "CONCAT_WS('-', x)",
write={ write={
"duckdb": "CONCAT_WS('-', x)", "duckdb": "CONCAT_WS('-', x)",
"presto": "ARRAY_JOIN(x, '-')",
"hive": "CONCAT_WS('-', x)", "hive": "CONCAT_WS('-', x)",
"presto": "CONCAT_WS('-', x)",
"spark": "CONCAT_WS('-', x)", "spark": "CONCAT_WS('-', x)",
"trino": "CONCAT_WS('-', x)",
}, },
) )
self.validate_all( self.validate_all(
@ -1118,6 +1128,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY", "SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY",
write={ write={
"sqlite": "SELECT x FROM y LIMIT 3 OFFSET 10",
"oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
}, },
) )
@ -1197,7 +1208,7 @@ class TestDialect(Validator):
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
"redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))", "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 VARCHAR(MAX), c2 VARCHAR(1024))",
}, },
) )
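Two of the cross-dialect changes above are the swapped STR_POSITION argument order (haystack first) and Presto's native three-argument STRPOS. A minimal sketch, with expected outputs as in the tests above:
import sqlglot

print(sqlglot.transpile("STR_POSITION(x, 'a')", write="duckdb")[0])  # expected: STRPOS(x, 'a')
print(sqlglot.transpile("POSITION('a', x, 3)", write="presto")[0])   # expected: STRPOS(x, 'a', 3)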

View file

@ -356,6 +356,30 @@ class TestHive(Validator):
"spark": "SELECT a_b AS 1_a FROM test_table", "spark": "SELECT a_b AS 1_a FROM test_table",
}, },
) )
self.validate_all(
"SELECT 1a_1a FROM test_a",
write={
"spark": "SELECT 1a_1a FROM test_a",
},
)
self.validate_all(
"SELECT 1a AS 1a_1a FROM test_a",
write={
"spark": "SELECT 1a AS 1a_1a FROM test_a",
},
)
self.validate_all(
"CREATE TABLE test_table (1a STRING)",
write={
"spark": "CREATE TABLE test_table (1a STRING)",
},
)
self.validate_all(
"CREATE TABLE test_table2 (1a_1a STRING)",
write={
"spark": "CREATE TABLE test_table2 (1a_1a STRING)",
},
)
self.validate_all( self.validate_all(
"PERCENTILE(x, 0.5)", "PERCENTILE(x, 0.5)",
write={ write={
@ -420,7 +444,7 @@ class TestHive(Validator):
"LOCATE('a', x, 3)", "LOCATE('a', x, 3)",
write={ write={
"duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "presto": "STRPOS(x, 'a', 3)",
"hive": "LOCATE('a', x, 3)", "hive": "LOCATE('a', x, 3)",
"spark": "LOCATE('a', x, 3)", "spark": "LOCATE('a', x, 3)",
}, },
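The tokenizer change that lets numeric-prefixed identifiers keep scanning past the digits is what the Hive tests above exercise. A small sketch:
import sqlglot

# Identifiers such as 1a_1a now survive Hive -> Spark transpilation unchanged.
print(sqlglot.transpile("SELECT 1a_1a FROM test_a", read="hive", write="spark")[0])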

View file

@ -65,6 +65,17 @@ class TestMySQL(Validator):
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
self.validate_identity("SELECT SCHEMA()") self.validate_identity("SELECT SCHEMA()")
def test_types(self):
self.validate_all(
"CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
read={
"mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
},
write={
"spark": "CAST(x AS TEXT) + CAST(y AS TEXT)",
},
)
def test_canonical_functions(self): def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")

View file

@ -46,14 +46,6 @@ class TestPostgres(Validator):
" CONSTRAINT valid_discount CHECK (price > discounted_price))" " CONSTRAINT valid_discount CHECK (price > discounted_price))"
}, },
) )
self.validate_all(
"CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
)
self.validate_all(
"CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
)
with self.assertRaises(ParseError): with self.assertRaises(ParseError):
transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")

View file

@ -152,6 +152,10 @@ class TestPresto(Validator):
"spark": "FROM_UNIXTIME(x)", "spark": "FROM_UNIXTIME(x)",
}, },
) )
self.validate_identity("FROM_UNIXTIME(a, b)")
self.validate_identity("FROM_UNIXTIME(a, b, c)")
self.validate_identity("TRIM(a, b)")
self.validate_identity("VAR_POP(a)")
self.validate_all( self.validate_all(
"TO_UNIXTIME(x)", "TO_UNIXTIME(x)",
write={ write={
@ -302,6 +306,7 @@ class TestPresto(Validator):
) )
def test_presto(self): def test_presto(self):
self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
self.validate_all( self.validate_all(
'SELECT a."b" FROM "foo"', 'SELECT a."b" FROM "foo"',
write={ write={
@ -443,8 +448,10 @@ class TestPresto(Validator):
"spark": UnsupportedError, "spark": UnsupportedError,
}, },
) )
self.validate_identity("SELECT * FROM (VALUES (1))")
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
def test_encode_decode(self): def test_encode_decode(self):
self.validate_all( self.validate_all(
@ -459,6 +466,12 @@ class TestPresto(Validator):
"spark": "DECODE(x, 'utf-8')", "spark": "DECODE(x, 'utf-8')",
}, },
) )
self.validate_all(
"FROM_UTF8(x, y)",
write={
"presto": "FROM_UTF8(x, y)",
},
)
self.validate_all( self.validate_all(
"ENCODE(x, 'utf-8')", "ENCODE(x, 'utf-8')",
write={ write={
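Per the new Presto identities above, several multi-argument functions that previously dropped arguments now round-trip. A quick sketch:
import sqlglot

for q in ["FROM_UNIXTIME(a, b, c)", "APPROX_PERCENTILE(a, b, c, d)", "SELECT * FROM (VALUES (1))"]:
    print(sqlglot.transpile(q, read="presto", write="presto")[0])  # each expected unchanged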

View file

@ -89,7 +89,9 @@ class TestRedshift(Validator):
self.validate_identity( self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
) )
self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL") self.validate_identity(
"CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL"
)
self.validate_identity( self.validate_identity(
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
) )
@ -102,3 +104,81 @@ class TestRedshift(Validator):
self.validate_identity( self.validate_identity(
"CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
) )
def test_values(self):
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t",
},
)
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
},
)
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)",
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t",
},
)
self.validate_all(
"INSERT INTO t(a) VALUES (1), (2), (3)",
write={
"redshift": "INSERT INTO t (a) VALUES (1), (2), (3)",
},
)
self.validate_all(
"INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
write={
"redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
},
)
self.validate_all(
"INSERT INTO t(a, b) VALUES (1, 2), (3, 4)",
write={
"redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)",
},
)
def test_create_table_like(self):
self.validate_all(
"CREATE TABLE t1 LIKE t2",
write={
"redshift": "CREATE TABLE t1 (LIKE t2)",
},
)
self.validate_all(
"CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
write={
"redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
},
)
def test_rename_table(self):
self.validate_all(
"ALTER TABLE db.t1 RENAME TO db.t2",
write={
"spark": "ALTER TABLE db.t1 RENAME TO db.t2",
"redshift": "ALTER TABLE db.t1 RENAME TO t2",
},
)
def test_varchar_max(self):
self.validate_all(
"CREATE TABLE TEST (cola VARCHAR(MAX))",
write={
"redshift": 'CREATE TABLE "TEST" ("cola" VARCHAR(MAX))',
},
identify=True,
)
def test_no_schema_binding(self):
self.validate_all(
"CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
write={
"redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
},
)
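The Redshift tests above pin down the VALUES rewrite: Redshift has no VALUES-derived tables, so they become SELECT ... UNION ALL. A minimal sketch:
import sqlglot

sql = "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)"
print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])
# expected: SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t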

View file

@ -307,5 +307,12 @@ TBLPROPERTIES (
def test_iif(self): def test_iif(self):
self.validate_all( self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} "SELECT IIF(cond, 'True', 'False')",
write={"spark": "SELECT IF(cond, 'True', 'False')"},
)
def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
) )

View file

@ -0,0 +1,23 @@
from tests.dialects.test_dialect import Validator
class TestTeradata(Validator):
dialect = "teradata"
def test_translate(self):
self.validate_all(
"TRANSLATE(x USING LATIN_TO_UNICODE)",
write={
"teradata": "CAST(x AS CHAR CHARACTER SET UNICODE)",
},
)
self.validate_identity("CAST(x AS CHAR CHARACTER SET UNICODE)")
def test_update(self):
self.validate_all(
"UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
write={
"teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
"mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
},
)
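This release adds a Teradata dialect; the new test file above covers TRANSLATE ... USING and Teradata-style UPDATE ... FROM. A minimal sketch of the TRANSLATE rewrite:
import sqlglot

print(sqlglot.transpile("TRANSLATE(x USING LATIN_TO_UNICODE)",
                        read="teradata", write="teradata")[0])
# expected: CAST(x AS CHAR CHARACTER SET UNICODE)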

View file

@ -5,6 +5,13 @@ class TestTSQL(Validator):
dialect = "tsql" dialect = "tsql"
def test_tsql(self): def test_tsql(self):
self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
self.validate_identity("PRINT @TestVariable")
self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)")
self.validate_identity(
"SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID"
)
self.validate_identity('SELECT "x"."y" FROM foo') self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo") self.validate_identity("SELECT * FROM ##foo")

View file

@ -59,6 +59,8 @@ map.x
SELECT call.x SELECT call.x
a.b.INT(1.234) a.b.INT(1.234)
INT(x / 100) INT(x / 100)
time * 100
int * 100
x IN (-1, 1) x IN (-1, 1)
x IN ('a', 'a''a') x IN ('a', 'a''a')
x IN ((1)) x IN ((1))
@ -69,6 +71,11 @@ x IS TRUE
x IS FALSE x IS FALSE
x IS TRUE IS TRUE x IS TRUE IS TRUE
x LIKE y IS TRUE x LIKE y IS TRUE
MAP()
GREATEST(x)
LEAST(y)
MAX(a, b)
MIN(a, b)
time time
zone zone
ARRAY<TEXT> ARRAY<TEXT>
@ -133,6 +140,7 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1 SET x = 1
SET -v SET -v
SET x = ';'
COMMIT COMMIT
USE db USE db
NOT 1 NOT 1
@ -170,6 +178,7 @@ SELECT COUNT(DISTINCT a, b)
SELECT COUNT(DISTINCT a, b + 1) SELECT COUNT(DISTINCT a, b + 1)
SELECT SUM(DISTINCT x) SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x SELECT SUM(x IGNORE NULLS) AS x
SELECT TRUNCATE(a, b)
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x
@ -622,7 +631,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
SELECT x AS INTO FROM bla SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event SELECT * INTO newevent FROM event
@ -643,3 +652,21 @@ ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
SELECT div.a FROM test_table AS div SELECT div.a FROM test_table AS div
WITH view AS (SELECT 1 AS x) SELECT * FROM view
CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA
CREATE TABLE asd AS SELECT asd FROM asd WITH DATA
ARRAY<STRUCT<INT, DOUBLE, ARRAY<INT>>>
ARRAY<INT>[1, 2, 3]
ARRAY<INT>[]
STRUCT<x VARCHAR(10)>
STRUCT<x VARCHAR(10)>("bla")
STRUCT<VARCHAR(10)>("bla")
STRUCT<INT>(5)
STRUCT<DATE>("2011-05-05")
STRUCT<x INT, y TEXT>(1, t.str_col)
SELECT CAST(NULL AS ARRAY<INT>) IS NULL AS array_is_null
CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)
CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))
CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10))

View file

@ -322,3 +322,23 @@ SELECT
* /* multi * /* multi
line line
comment */; comment */;
WITH table_data AS (
SELECT 'bob' AS name, ARRAY['banana', 'apple', 'orange'] AS fruit_basket
)
SELECT
name,
fruit,
basket_index
FROM table_data
CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index;
WITH table_data AS (
SELECT
'bob' AS name,
ARRAY('banana', 'apple', 'orange') AS fruit_basket
)
SELECT
name,
fruit,
basket_index
FROM table_data
CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index;
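The UNNEST ... WITH OFFSET forms in the fixture above tie into the new offset argument added to exp.Unnest in the parser hunk earlier. A minimal sketch that just parses and pretty-prints one of them (the exact formatting is not asserted here):
import sqlglot

sql = (
    "SELECT name, fruit, basket_index FROM table_data "
    "CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index"
)
print(sqlglot.parse_one(sql).sql(pretty=True))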

View file

@ -624,6 +624,10 @@ FROM foo""",
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
exp.to_table(1) exp.to_table(1)
empty_string = exp.to_table("")
self.assertEqual(empty_string.name, "")
self.assertIsNone(table_only.args.get("db"))
self.assertIsNone(table_only.args.get("catalog"))
def test_to_column(self): def test_to_column(self):
column_only = exp.to_column("column_name") column_only = exp.to_column("column_name")
@ -715,3 +719,9 @@ FROM foo""",
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT") self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL") self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN") self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
def test_rename_table(self):
self.assertEqual(
exp.rename_table("t1", "t2").sql(),
"ALTER TABLE t1 RENAME TO t2",
)
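The new exp.rename_table builder tested above constructs the same ALTER TABLE ... RENAME TO statement programmatically. A minimal sketch:
from sqlglot import exp

print(exp.rename_table("t1", "t2").sql())  # expected: ALTER TABLE t1 RENAME TO t2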

View file

@ -6,7 +6,7 @@ from pandas.testing import assert_frame_equal
import sqlglot import sqlglot
from sqlglot import exp, optimizer, parse_one from sqlglot import exp, optimizer, parse_one
from sqlglot.errors import OptimizeError from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema from sqlglot.schema import MappingSchema
@ -161,7 +161,7 @@ class TestOptimizer(unittest.TestCase):
def test_qualify_columns__invalid(self): def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql): with self.subTest(sql):
with self.assertRaises(OptimizeError): with self.assertRaises((OptimizeError, SchemaError)):
optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_lower_identities(self): def test_lower_identities(self):

View file

@ -325,3 +325,9 @@ class TestParser(unittest.TestCase):
"Expected table name", "Expected table name",
logger, logger,
) )
def test_rename_table(self):
self.assertEqual(
parse_one("ALTER TABLE foo RENAME TO bar").sql(),
"ALTER TABLE foo RENAME TO bar",
)

View file

@ -272,6 +272,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
"WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2", "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
"WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2", "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
) )
self.validate(
"SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)",
"SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)",
write="presto",
)
def test_alter(self): def test_alter(self):
self.validate( self.validate(
@ -447,6 +452,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
self.assertEqual(generated, pretty) self.assertEqual(generated, pretty)
self.assertEqual(parse_one(sql), parse_one(pretty)) self.assertEqual(parse_one(sql), parse_one(pretty))
def test_pretty_line_breaks(self):
self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'")
@mock.patch("sqlglot.parser.logger") @mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger): def test_error_level(self, logger):
invalid = "x + 1. (" invalid = "x + 1. ("