
Merging upstream version 6.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:31:47 +01:00
parent 0822fbed3a
commit 9bc11b290e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 312 additions and 45 deletions

CHANGELOG.md

@@ -1,6 +1,23 @@
Changelog
=========
v6.2.0
------
Changes:
- New: TSQL support
- Breaking: Removed $ from tokenizer, added @ placeholders
- Improvement: Nodes can now be removed in transform and replace [8cd81c3](https://github.com/tobymao/sqlglot/commit/8cd81c36561463b9849a8e0c2d70248c5b1feb62)
- Improvement: Snowflake timestamp support
- Improvement: Property conversion for CTAS Builder
- Improvement: Tokenizers are now unique per dialect instance
v6.1.0
------
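
The tokenizer change above is the breaking one: "$" is no longer a generic single-character token in the base tokenizer, and "@" now produces parameter expressions. A minimal sketch of the new behavior, based on the tests added in this commit:

from sqlglot import parse_one

# "@" now tokenizes as PARAMETER and round-trips as exp.Parameter nodes.
print(parse_one("SELECT @x, @@x, @1").sql())  # SELECT @x, @@x, @1

# "$" is no longer special outside Snowflake, so identifiers such as
# Oracle's V$SESSION tokenize as plain variables.
print(parse_one("SELECT * FROM V$SESSION", read="oracle").sql(dialect="oracle"))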

sqlglot/__init__.py

@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.1.1"
__version__ = "6.2.0"
pretty = False

sqlglot/dialects/__init__.py

@@ -14,3 +14,4 @@ from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL

sqlglot/dialects/dialect.py

@@ -27,6 +27,7 @@ class Dialects(str, Enum):
STARROCKS = "starrocks"
TABLEAU = "tableau"
TRINO = "trino"
TSQL = "tsql"
class _Dialect(type):
@@ -53,7 +54,6 @@ class _Dialect(type):
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
@@ -95,7 +95,6 @@ class Dialect(metaclass=_Dialect):
tokenizer_class = None
parser_class = None
generator_class = None
tokenizer = None
@classmethod
def get_or_raise(cls, dialect):
@@ -138,6 +137,12 @@ class Dialect(metaclass=_Dialect):
def transpile(self, code, **opts):
return self.generate(self.parse(code), **opts)
@property
def tokenizer(self):
if not hasattr(self, "_tokenizer"):
self._tokenizer = self.tokenizer_class()
return self._tokenizer
def parser(self, **opts):
return self.parser_class(
**{
@@ -170,7 +175,15 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
def _rename(self, expression):
args = (
self.expressions(expression, flat=True)
if isinstance(expression, exp.Func) and expression.is_var_len_args
else csv(*[self.sql(e) for e in expression.args.values()])
)
return f"{name}({args})"
return _rename
def approx_count_distinct_sql(self, expression):
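
The tokenizer is now built lazily, once per Dialect instance, via the property above (cached in _tokenizer) rather than eagerly on the class. A small sketch of the resulting behavior; the usage is hypothetical but follows directly from the property:

from sqlglot.dialects.dialect import Dialect

DuckDB = Dialect.get_or_raise("duckdb")
d1, d2 = DuckDB(), DuckDB()
assert d1.tokenizer is d1.tokenizer      # cached after the first access
assert d1.tokenizer is not d2.tokenizer  # unique per dialect instance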

sqlglot/dialects/duckdb.py

@@ -108,7 +108,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
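
With exp.Array routed through rename_func, DuckDB now writes its list constructor the same way it reads it. A sketch mirroring the test added in this commit:

import sqlglot

print(sqlglot.transpile("SELECT ARRAY(0, 1, 2)", read="spark", write="duckdb"))
# ['SELECT LIST_VALUE(0, 1, 2)']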

sqlglot/dialects/snowflake.py

@@ -106,6 +106,11 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
"DATE_PART": lambda self: self._parse_extract(),
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@@ -118,10 +123,20 @@ class Snowflake(Dialect):
class Tokenizer(Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
"$": TokenType.DOLLAR, # needed to break for quotes
}
KEYWORDS = {
**Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
}
class Generator(Generator):
@@ -132,6 +147,11 @@ class Snowflake(Dialect):
exp.UnixToTime: _unix_to_time,
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
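
Together these changes map the generic TIMESTAMP type to Snowflake's TIMESTAMPNTZ and parse DATE_PART as EXTRACT. A sketch mirroring the new tests below:

import sqlglot

print(sqlglot.transpile("SELECT CAST(a AS TIMESTAMP)", write="snowflake"))
# ['SELECT CAST(a AS TIMESTAMPNTZ)']

print(sqlglot.transpile("SELECT DATE_PART('month', a)", read="snowflake", write="snowflake"))
# ["SELECT EXTRACT('month' FROM a)"]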

sqlglot/dialects/spark.py

@@ -82,6 +82,7 @@ class Spark(Hive):
TRANSFORMS = {
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
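
Spark has no ARRAY_SUM, so exp.ArraySum is lowered to a higher-order AGGREGATE call. A sketch, assuming ARRAY_SUM parses to exp.ArraySum in the default dialect:

import sqlglot

print(sqlglot.transpile("SELECT ARRAY_SUM(ARRAY(1, 2, 3))", write="spark"))
# ['SELECT AGGREGATE(ARRAY(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc)']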

sqlglot/dialects/tsql.py (new file)

@@ -0,0 +1,38 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer, TokenType
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
"SMALLDATETIME": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"VARBINARY": TokenType.BINARY,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
"ROWVERSION": TokenType.ROWVERSION,
"SQL_VARIANT": TokenType.SQL_VARIANT,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML,
}
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
}
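
The new dialect registers as "tsql", reads bracketed identifiers, and maps T-SQL types onto sqlglot's generic ones. A sketch mirroring the tests added below:

import sqlglot

print(sqlglot.transpile("SELECT CAST([a].[b] AS SMALLINT) FROM foo", read="tsql", write="spark"))
# ['SELECT CAST(`a`.`b` AS SHORT) FROM foo']

print(sqlglot.transpile("SELECT CAST(x AS BOOLEAN)", write="tsql"))
# ['SELECT CAST(x AS BIT)']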

sqlglot/expressions.py

@@ -1,4 +1,5 @@
import inspect
import numbers
import re
import sys
from collections import deque
@@ -6,7 +7,7 @@ from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
class _Expression(type):
@@ -350,7 +351,8 @@ class Expression(metaclass=_Expression):
Args:
fun (function): a function which takes a node as an argument and returns a
new transformed node or the same node without modifications.
new transformed node or the same node without modifications. If the function
returns None, then the corresponding node will be removed from the syntax tree.
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
modified in place.
@@ -360,9 +362,7 @@
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
if new_node is None:
raise ValueError("A transformed node cannot be None")
if not isinstance(new_node, Expression):
if new_node is None or not isinstance(new_node, Expression):
return new_node
if new_node is not node:
new_node.parent = node.parent
@@ -843,10 +843,6 @@ class Ordered(Expression):
arg_types = {"this": True, "desc": True, "nulls_first": True}
class Properties(Expression):
arg_types = {"expressions": True}
class Property(Expression):
arg_types = {"this": True, "value": True}
@@ -891,6 +887,42 @@ class AnonymousProperty(Property):
pass
class Properties(Expression):
arg_types = {"expressions": True}
PROPERTY_KEY_MAPPING = {
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER_SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"ENGINE": EngineProperty,
"FORMAT": FileFormatProperty,
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
}
@classmethod
def from_dict(cls, properties_dict):
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
return cls(expressions=expressions)
@staticmethod
def _convert_value(value):
if isinstance(value, Expression):
return value
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, list):
return Tuple(expressions=[Properties._convert_value(v) for v in value])
raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
class Qualify(Expression):
pass
@@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression):
)
properties_expression = None
if properties:
properties_str = " ".join(
[f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
into=Properties,
dialect=dialect,
**opts,
)
properties_expression = Properties.from_dict(properties)
return Create(
this=table_expression,
@@ -1650,6 +1674,10 @@ class Star(Expression):
return "*"
class Parameter(Expression):
pass
class Placeholder(Expression):
arg_types = {}
@@ -1688,6 +1716,7 @@ class DataType(Expression):
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATE = auto()
DATETIME = auto()
ARRAY = auto()
@@ -1702,6 +1731,13 @@ class DataType(Expression):
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
UNIQUEIDENTIFIER = auto()
MONEY = auto()
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@@ -2976,7 +3012,7 @@ def replace_children(expression, fun):
else:
new_child_nodes.append(cn)
expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
def column_table_names(expression):
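
Two behavioral changes are visible above: transform (and replace) now treat a None return as node removal instead of raising, and CTAS properties are built structurally with Properties.from_dict rather than by formatting and re-parsing a property string. A sketch of both, mirroring the updated tests:

from sqlglot import exp, parse_one

tree = parse_one("SELECT a, b FROM x")
pruned = tree.transform(
    lambda node: None if isinstance(node, exp.Column) and node.name == "b" else node
)
print(pruned.sql())  # SELECT a FROM x

# Known keys map to typed Property subclasses, everything else to AnonymousProperty.
props = exp.Properties.from_dict({"FORMAT": "parquet", "x": 1})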

sqlglot/generator.py

@@ -748,6 +748,9 @@ class Generator:
def structkwarg_sql(self, expression):
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"
def placeholder_sql(self, *_):
return "?"
@@ -903,7 +906,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")
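
The if_sql change matters because case_sql reparents the expressions handed to it; copying first keeps generation side-effect free, so calling .sql() no longer mutates the parsed tree. The new parameter_sql hook renders exp.Parameter nodes as @name. A sketch of the intended behavior:

from sqlglot import parse_one

e = parse_one("IF(a > 0, a, b)")
e.sql(dialect="presto")   # 'IF(a > 0, a, b)'
assert e.parent is None   # generation did not reparent the tree

print(parse_one("SELECT @x").sql())  # SELECT @x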

sqlglot/parser.py

@@ -81,6 +81,7 @@ class Parser:
TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
TokenType.DATETIME,
TokenType.DATE,
TokenType.DECIMAL,
@@ -92,6 +93,13 @@ class Parser:
TokenType.SERIAL,
TokenType.SMALLSERIAL,
TokenType.BIGSERIAL,
TokenType.XML,
TokenType.UNIQUEIDENTIFIER,
TokenType.MONEY,
TokenType.SMALLMONEY,
TokenType.ROWVERSION,
TokenType.IMAGE,
TokenType.SQL_VARIANT,
*NESTED_TYPE_TOKENS,
}
@@ -233,6 +241,7 @@ class Parser:
TIMESTAMPS = {
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
}
SET_OPERATIONS = {
@@ -315,6 +324,7 @@ class Parser:
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self.expression(
@@ -1497,12 +1507,19 @@ class Parser:
if type_token in self.TIMESTAMPS:
tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
self._match(TokenType.WITHOUT_TIME_ZONE)
if tz:
return exp.DataType(
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
if ltz:
return exp.DataType(
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
self._match(TokenType.WITHOUT_TIME_ZONE)
return exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
@@ -1845,8 +1862,11 @@ class Parser:
def _parse_extract(self):
this = self._parse_var() or self._parse_type()
if not self._match(TokenType.FROM):
self.raise_error("Expected FROM after EXTRACT", self._prev)
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
if not self._match(TokenType.COMMA):
self.raise_error("Expected FROM or comma after EXTRACT", self._prev)
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
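
The parser now distinguishes WITH LOCAL TIME ZONE from WITH TIME ZONE, and _parse_extract accepts both EXTRACT(part FROM x) and the comma form used by Snowflake's DATE_PART. A sketch of the LTZ path, mirroring the Snowflake tests:

from sqlglot import parse_one

print(parse_one("SELECT CAST(a AS TIMESTAMP WITH LOCAL TIME ZONE)").sql(dialect="snowflake"))
# SELECT CAST(a AS TIMESTAMPLTZ)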

sqlglot/tokens.py

@@ -41,6 +41,7 @@ class TokenType(AutoName):
LR_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
SPACE = auto()
BREAK = auto()
@@ -75,6 +76,7 @@ class TokenType(AutoName):
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATETIME = auto()
DATE = auto()
UUID = auto()
@@ -86,6 +88,13 @@ class TokenType(AutoName):
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
UNIQUEIDENTIFIER = auto()
MONEY = auto()
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
# keywords
ADD_FILE = auto()
@@ -247,6 +256,7 @@ class TokenType(AutoName):
WINDOW = auto()
WITH = auto()
WITH_TIME_ZONE = auto()
WITH_LOCAL_TIME_ZONE = auto()
WITHIN_GROUP = auto()
WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()
@@ -340,7 +350,7 @@ class Tokenizer(metaclass=_Tokenizer):
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"#": TokenType.ANNOTATION,
"$": TokenType.DOLLAR,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
@@ -520,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer):
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"ARRAY": TokenType.ARRAY,
@@ -561,6 +572,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BYTEA": TokenType.BINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
"DATE": TokenType.DATE,
"DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE,

tests/dialects/test_dialect.py

@@ -228,6 +228,7 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
},
)
@@ -237,6 +238,7 @@
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
},
)
@@ -246,6 +248,7 @@
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
"redshift": "TO_TIMESTAMP(x, 'YY')",
"spark": "TO_TIMESTAMP(x, 'yy')",
},
)
@@ -287,6 +290,7 @@
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
"redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
},
)
self.validate_all(
@@ -295,6 +299,7 @@
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
"redshift": "CAST(x AS TEXT)",
},
)
self.validate_all(

tests/dialects/test_duckdb.py

@@ -66,6 +66,9 @@ class TestDuckDB(Validator):
def test_duckdb(self):
self.validate_all(
"LIST_VALUE(0, 1, 2)",
read={
"spark": "ARRAY(0, 1, 2)",
},
write={
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",

tests/dialects/test_hive.py

@@ -131,7 +131,7 @@ class TestHive(Validator):
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
},
)
self.validate_all(

tests/dialects/test_oracle.py (new file)

@@ -0,0 +1,6 @@
from tests.dialects.test_dialect import Validator
class TestOracle(Validator):
def test_oracle(self):
self.validate_identity("SELECT * FROM V$SESSION")

tests/dialects/test_presto.py

@@ -173,7 +173,7 @@ class TestPresto(Validator):
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
)
self.validate_all(
@@ -181,7 +181,7 @@ class TestPresto(Validator):
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
},
)
self.validate_all(

tests/dialects/test_snowflake.py

@@ -175,3 +175,48 @@
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
},
)
def test_timestamps(self):
self.validate_all(
"SELECT CAST(a AS TIMESTAMP)",
write={
"snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)",
},
)
self.validate_all(
"SELECT a::TIMESTAMP_LTZ(9)",
write={
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ(9))",
},
)
self.validate_all(
"SELECT a::TIMESTAMPLTZ",
write={
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
},
)
self.validate_all(
"SELECT a::TIMESTAMP WITH LOCAL TIME ZONE",
write={
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
},
)
self.validate_identity("SELECT EXTRACT(month FROM a)")
self.validate_all(
"SELECT EXTRACT('month', a)",
write={
"snowflake": "SELECT EXTRACT('month' FROM a)",
},
)
self.validate_all(
"SELECT DATE_PART('month', a)",
write={
"snowflake": "SELECT EXTRACT('month' FROM a)",
},
)
self.validate_all(
"SELECT DATE_PART(month FROM a::DATETIME)",
write={
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
},
)

tests/dialects/test_spark.py

@@ -44,15 +44,7 @@ class TestSpark(Validator):
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
},
)
self.validate_all(
@@ -86,7 +78,7 @@ COMMENT 'Test comment: blah'
PARTITIONED BY (
date STRING
)
STORED AS ICEBERG
USING ICEBERG
TBLPROPERTIES (
'x' = '1'
)""",

tests/dialects/test_tsql.py (new file)

@@ -0,0 +1,26 @@
from tests.dialects.test_dialect import Validator
class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_all(
"SELECT CAST([a].[b] AS SMALLINT) FROM foo",
write={
"tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo',
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
},
)
def test_types(self):
self.validate_identity("CAST(x AS XML)")
self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
self.validate_identity("CAST(x AS MONEY)")
self.validate_identity("CAST(x AS SMALLMONEY)")
self.validate_identity("CAST(x AS ROWVERSION)")
self.validate_identity("CAST(x AS IMAGE)")
self.validate_identity("CAST(x AS SQL_VARIANT)")
self.validate_identity("CAST(x AS BIT)")

tests/test_expressions.py

@@ -224,9 +224,6 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
self.assertIs(actual_expression_2, expression)
with self.assertRaises(ValueError):
parse_one("a").transform(lambda n: None)
def test_transform_no_infinite_recursion(self):
expression = parse_one("a")
@@ -247,6 +244,35 @@
self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x")
def test_transform_node_removal(self):
expression = parse_one("SELECT a, b FROM x")
def remove_column_b(node):
if isinstance(node, exp.Column) and node.name == "b":
return None
return node
self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x")
self.assertEqual(expression.transform(lambda _: None), None)
expression = parse_one("CAST(x AS FLOAT)")
def remove_non_list_arg(node):
if isinstance(node, exp.DataType):
return None
return node
self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )")
expression = parse_one("SELECT a, b FROM x")
def remove_all_columns(node):
if isinstance(node, exp.Column):
return None
return node
self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x")
def test_replace(self):
expression = parse_one("SELECT a, b FROM x")
expression.find(exp.Column).replace(parse_one("c"))

tests/test_parser.py

@@ -114,6 +114,9 @@ class TestParser(unittest.TestCase):
with self.assertRaises(ParseError):
parse_one("SELECT FROM x ORDER BY")
def test_parameter(self):
self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
def test_annotations(self):
expression = parse_one(
"""