
Merging upstream version 6.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:31:47 +01:00
parent 0822fbed3a
commit 9bc11b290e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 312 additions and 45 deletions

CHANGELOG.md

@@ -1,6 +1,23 @@
 Changelog
 =========
+
+v6.2.0
+------
+
+Changes:
+
+- New: TSQL support
+- Breaking: Removed $ from tokenizer, added @ placeholders
+- Improvement: Nodes can now be removed in transform and replace [8cd81c3](https://github.com/tobymao/sqlglot/commit/8cd81c36561463b9849a8e0c2d70248c5b1feb62)
+- Improvement: Snowflake timestamp support
+- Improvement: Property conversion for CTAS Builder
+- Improvement: Tokenizers are now unique per dialect instance
+
 v6.1.0
 ------

sqlglot/__init__.py

@@ -20,7 +20,7 @@ from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "6.1.1"
+__version__ = "6.2.0"

 pretty = False

sqlglot/dialects/__init__.py

@@ -14,3 +14,4 @@ from sqlglot.dialects.sqlite import SQLite
 from sqlglot.dialects.starrocks import StarRocks
 from sqlglot.dialects.tableau import Tableau
 from sqlglot.dialects.trino import Trino
+from sqlglot.dialects.tsql import TSQL

sqlglot/dialects/dialect.py

@@ -27,6 +27,7 @@ class Dialects(str, Enum):
     STARROCKS = "starrocks"
     TABLEAU = "tableau"
     TRINO = "trino"
+    TSQL = "tsql"


 class _Dialect(type):
@@ -53,7 +54,6 @@ class _Dialect(type):
         klass.parser_class = getattr(klass, "Parser", Parser)
         klass.generator_class = getattr(klass, "Generator", Generator)

-        klass.tokenizer = klass.tokenizer_class()
         klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
@@ -95,7 +95,6 @@ class Dialect(metaclass=_Dialect):
     tokenizer_class = None
     parser_class = None
    generator_class = None
-    tokenizer = None

     @classmethod
     def get_or_raise(cls, dialect):
@@ -138,6 +137,12 @@ class Dialect(metaclass=_Dialect):
     def transpile(self, code, **opts):
         return self.generate(self.parse(code), **opts)

+    @property
+    def tokenizer(self):
+        if not hasattr(self, "_tokenizer"):
+            self._tokenizer = self.tokenizer_class()
+        return self._tokenizer
+
     def parser(self, **opts):
         return self.parser_class(
             **{
@@ -170,7 +175,15 @@ class Dialect(metaclass=_Dialect):
 def rename_func(name):
-    return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
+    def _rename(self, expression):
+        args = (
+            self.expressions(expression, flat=True)
+            if isinstance(expression, exp.Func) and expression.is_var_len_args
+            else csv(*[self.sql(e) for e in expression.args.values()])
+        )
+        return f"{name}({args})"
+
+    return _rename


 def approx_count_distinct_sql(self, expression):
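A quick sketch of what the per-instance tokenizer change means for callers (illustrative usage, not from this commit; only the tokenizer property and its _tokenizer cache attribute come from the diff above):

    from sqlglot.dialects.dialect import Dialect

    snowflake_cls = Dialect.get_or_raise("snowflake")
    d1, d2 = snowflake_cls(), snowflake_cls()

    # The tokenizer is now built lazily on first access and cached per instance,
    # so two dialect instances no longer share a single class-level Tokenizer.
    assert d1.tokenizer is d1.tokenizer
    assert d1.tokenizer is not d2.tokenizer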

sqlglot/dialects/duckdb.py

@@ -108,7 +108,7 @@ class DuckDB(Dialect):
     TRANSFORMS = {
         **Generator.TRANSFORMS,
         exp.ApproxDistinct: approx_count_distinct_sql,
-        exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
+        exp.Array: rename_func("LIST_VALUE"),
         exp.ArraySize: rename_func("ARRAY_LENGTH"),
         exp.ArraySort: _array_sort_sql,
         exp.ArraySum: rename_func("LIST_SUM"),
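Since exp.Array is a variable-length-args function, the new rename_func can now render it directly; per the updated DuckDB test further down, this transpiles as:

    import sqlglot

    # Spark's ARRAY(...) reads into DuckDB's LIST_VALUE(...)
    print(sqlglot.transpile("ARRAY(0, 1, 2)", read="spark", write="duckdb")[0])
    # LIST_VALUE(0, 1, 2)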

sqlglot/dialects/snowflake.py

@@ -106,6 +106,11 @@ class Snowflake(Dialect):
             "TO_TIMESTAMP": _snowflake_to_timestamp,
         }

+        FUNCTION_PARSERS = {
+            **Parser.FUNCTION_PARSERS,
+            "DATE_PART": lambda self: self._parse_extract(),
+        }
+
         COLUMN_OPERATORS = {
             **Parser.COLUMN_OPERATORS,
             TokenType.COLON: lambda self, this, path: self.expression(
@@ -118,10 +123,20 @@ class Snowflake(Dialect):
     class Tokenizer(Tokenizer):
         QUOTES = ["'", "$$"]
         ESCAPE = "\\"

+        SINGLE_TOKENS = {
+            **Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.DOLLAR,  # needed to break for quotes
+        }
+
         KEYWORDS = {
             **Tokenizer.KEYWORDS,
             "QUALIFY": TokenType.QUALIFY,
             "DOUBLE PRECISION": TokenType.DOUBLE,
+            "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
+            "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
+            "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
+            "TIMESTAMPNTZ": TokenType.TIMESTAMP,
         }

     class Generator(Generator):
@@ -132,6 +147,11 @@ class Snowflake(Dialect):
             exp.UnixToTime: _unix_to_time,
         }

+        TYPE_MAPPING = {
+            **Generator.TYPE_MAPPING,
+            exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
+        }
+
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")
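A short sketch of the new Snowflake timestamp handling, with expected outputs taken from the tests added below:

    import sqlglot

    # Plain TIMESTAMP maps to Snowflake's TIMESTAMPNTZ default
    print(sqlglot.transpile("SELECT CAST(a AS TIMESTAMP)", write="snowflake")[0])
    # SELECT CAST(a AS TIMESTAMPNTZ)

    # DATE_PART is parsed through _parse_extract and rendered as EXTRACT
    print(sqlglot.transpile("SELECT DATE_PART('month', a)", read="snowflake", write="snowflake")[0])
    # SELECT EXTRACT('month' FROM a)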

sqlglot/dialects/spark.py

@@ -82,6 +82,7 @@ class Spark(Hive):
     TRANSFORMS = {
         **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
+        exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
         exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
         exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
         exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
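With Spark now rendering file formats as USING rather than Hive's STORED AS, the updated Presto test below expects:

    import sqlglot

    print(sqlglot.transpile(
        "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
        read="presto",
        write="spark",
    )[0])
    # CREATE TABLE test USING PARQUET AS SELECT 1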

sqlglot/dialects/tsql.py (new file)

@@ -0,0 +1,38 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class TSQL(Dialect):
+    null_ordering = "nulls_are_small"
+    time_format = "'yyyy-mm-dd hh:mm:ss'"
+
+    class Tokenizer(Tokenizer):
+        IDENTIFIERS = ['"', ("[", "]")]
+
+        KEYWORDS = {
+            **Tokenizer.KEYWORDS,
+            "BIT": TokenType.BOOLEAN,
+            "REAL": TokenType.FLOAT,
+            "NTEXT": TokenType.TEXT,
+            "SMALLDATETIME": TokenType.DATETIME,
+            "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
+            "TIME": TokenType.TIMESTAMP,
+            "VARBINARY": TokenType.BINARY,
+            "IMAGE": TokenType.IMAGE,
+            "MONEY": TokenType.MONEY,
+            "SMALLMONEY": TokenType.SMALLMONEY,
+            "ROWVERSION": TokenType.ROWVERSION,
+            "SQL_VARIANT": TokenType.SQL_VARIANT,
+            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+            "XML": TokenType.XML,
+        }
+
+    class Generator(Generator):
+        TYPE_MAPPING = {
+            **Generator.TYPE_MAPPING,
+            exp.DataType.Type.BOOLEAN: "BIT",
+            exp.DataType.Type.INT: "INTEGER",
+            exp.DataType.Type.DECIMAL: "NUMERIC",
+        }
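A minimal usage sketch for the new dialect; the expected output comes from the TSQL tests added below:

    import sqlglot

    # Bracket-quoted identifiers and T-SQL types now parse
    print(sqlglot.transpile("SELECT CAST([a].[b] AS SMALLINT) FROM foo", read="tsql", write="spark")[0])
    # SELECT CAST(`a`.`b` AS SHORT) FROM foo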

sqlglot/expressions.py

@@ -1,4 +1,5 @@
 import inspect
+import numbers
 import re
 import sys
 from collections import deque
@@ -6,7 +7,7 @@ from copy import deepcopy
 from enum import auto

 from sqlglot.errors import ParseError
-from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
+from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get


 class _Expression(type):
@@ -350,7 +351,8 @@ class Expression(metaclass=_Expression):
         Args:
             fun (function): a function which takes a node as an argument and returns a
-                new transformed node or the same node without modifications.
+                new transformed node or the same node without modifications. If the function
+                returns None, then the corresponding node will be removed from the syntax tree.
             copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
                 modified in place.
@@ -360,9 +362,7 @@ class Expression(metaclass=_Expression):
         node = self.copy() if copy else self
         new_node = fun(node, *args, **kwargs)

-        if new_node is None:
-            raise ValueError("A transformed node cannot be None")
-        if not isinstance(new_node, Expression):
+        if new_node is None or not isinstance(new_node, Expression):
             return new_node

         if new_node is not node:
             new_node.parent = node.parent
@@ -843,10 +843,6 @@ class Ordered(Expression):
     arg_types = {"this": True, "desc": True, "nulls_first": True}


-class Properties(Expression):
-    arg_types = {"expressions": True}
-
-
 class Property(Expression):
     arg_types = {"this": True, "value": True}
@@ -891,6 +887,42 @@ class AnonymousProperty(Property):
     pass


+class Properties(Expression):
+    arg_types = {"expressions": True}
+
+    PROPERTY_KEY_MAPPING = {
+        "AUTO_INCREMENT": AutoIncrementProperty,
+        "CHARACTER_SET": CharacterSetProperty,
+        "COLLATE": CollateProperty,
+        "COMMENT": SchemaCommentProperty,
+        "ENGINE": EngineProperty,
+        "FORMAT": FileFormatProperty,
+        "LOCATION": LocationProperty,
+        "PARTITIONED_BY": PartitionedByProperty,
+        "TABLE_FORMAT": TableFormatProperty,
+    }
+
+    @classmethod
+    def from_dict(cls, properties_dict):
+        expressions = []
+        for key, value in properties_dict.items():
+            property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
+            expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
+        return cls(expressions=expressions)
+
+    @staticmethod
+    def _convert_value(value):
+        if isinstance(value, Expression):
+            return value
+        if isinstance(value, str):
+            return Literal.string(value)
+        if isinstance(value, numbers.Number):
+            return Literal.number(value)
+        if isinstance(value, list):
+            # qualified so the recursive call resolves inside a staticmethod
+            return Tuple(expressions=[Properties._convert_value(v) for v in value])
+        raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
+
+
 class Qualify(Expression):
     pass
@@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression):
     )

     properties_expression = None
     if properties:
-        properties_str = " ".join(
-            [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
-        )
-        properties_expression = maybe_parse(
-            properties_str,
-            into=Properties,
-            dialect=dialect,
-            **opts,
-        )
+        properties_expression = Properties.from_dict(properties)

     return Create(
         this=table_expression,
@@ -1650,6 +1674,10 @@ class Star(Expression):
         return "*"


+class Parameter(Expression):
+    pass
+
+
 class Placeholder(Expression):
     arg_types = {}
@@ -1688,6 +1716,7 @@ class DataType(Expression):
         INTERVAL = auto()
         TIMESTAMP = auto()
         TIMESTAMPTZ = auto()
+        TIMESTAMPLTZ = auto()
         DATE = auto()
         DATETIME = auto()
         ARRAY = auto()
@@ -1702,6 +1731,13 @@ class DataType(Expression):
         SERIAL = auto()
         SMALLSERIAL = auto()
         BIGSERIAL = auto()
+        XML = auto()
+        UNIQUEIDENTIFIER = auto()
+        MONEY = auto()
+        SMALLMONEY = auto()
+        ROWVERSION = auto()
+        IMAGE = auto()
+        SQL_VARIANT = auto()

     @classmethod
     def build(cls, dtype, **kwargs):
@@ -2976,7 +3012,7 @@ def replace_children(expression, fun):
         else:
             new_child_nodes.append(cn)

-    expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
+    expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)


 def column_table_names(expression):
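A brief sketch of the two behavior changes above, with expected results drawn from the new tests further down (the property values here are made up for illustration):

    from sqlglot import exp, parse_one

    # Returning None from a transform now prunes the node instead of raising
    tree = parse_one("SELECT a, b FROM x")
    pruned = tree.transform(lambda n: None if isinstance(n, exp.Column) and n.name == "b" else n)
    print(pruned.sql())  # SELECT a FROM x

    # Properties.from_dict turns plain Python values into typed Property nodes
    props = exp.Properties.from_dict({"FORMAT": "parquet", "ENGINE": "InnoDB"})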

sqlglot/generator.py

@@ -748,6 +748,9 @@ class Generator:
     def structkwarg_sql(self, expression):
         return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

+    def parameter_sql(self, expression):
+        return f"@{self.sql(expression, 'this')}"
+
     def placeholder_sql(self, *_):
         return "?"
@@ -903,7 +906,7 @@ class Generator:
         return f"UNIQUE ({columns})"

     def if_sql(self, expression):
-        return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
+        return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))

     def in_sql(self, expression):
         query = expression.args.get("query")
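Together with the tokenizer and parser changes below, parameter_sql makes @-style parameters round-trip; the expected output is from the new parser test:

    from sqlglot import parse_one

    print(parse_one("SELECT @x, @@x, @1").sql())
    # SELECT @x, @@x, @1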

sqlglot/parser.py

@@ -81,6 +81,7 @@ class Parser:
         TokenType.INTERVAL,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
+        TokenType.TIMESTAMPLTZ,
         TokenType.DATETIME,
         TokenType.DATE,
         TokenType.DECIMAL,
@@ -92,6 +93,13 @@ class Parser:
         TokenType.SERIAL,
         TokenType.SMALLSERIAL,
         TokenType.BIGSERIAL,
+        TokenType.XML,
+        TokenType.UNIQUEIDENTIFIER,
+        TokenType.MONEY,
+        TokenType.SMALLMONEY,
+        TokenType.ROWVERSION,
+        TokenType.IMAGE,
+        TokenType.SQL_VARIANT,
         *NESTED_TYPE_TOKENS,
     }
@@ -233,6 +241,7 @@ class Parser:
     TIMESTAMPS = {
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
+        TokenType.TIMESTAMPLTZ,
     }

     SET_OPERATIONS = {
@@ -315,6 +324,7 @@ class Parser:
         TokenType.TRUE: lambda *_: exp.Boolean(this=True),
         TokenType.FALSE: lambda *_: exp.Boolean(this=False),
         TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
+        TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
         TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
         TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
         TokenType.INTRODUCER: lambda self, token: self.expression(
@@ -1497,12 +1507,19 @@ class Parser:
         if type_token in self.TIMESTAMPS:
             tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
-            self._match(TokenType.WITHOUT_TIME_ZONE)
             if tz:
                 return exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPTZ,
                     expressions=expressions,
                 )
+            ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+            if ltz:
+                return exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMPLTZ,
+                    expressions=expressions,
+                )
+            self._match(TokenType.WITHOUT_TIME_ZONE)
             return exp.DataType(
                 this=exp.DataType.Type.TIMESTAMP,
                 expressions=expressions,
@@ -1845,8 +1862,11 @@ class Parser:
     def _parse_extract(self):
         this = self._parse_var() or self._parse_type()

-        if not self._match(TokenType.FROM):
-            self.raise_error("Expected FROM after EXTRACT", self._prev)
+        if self._match(TokenType.FROM):
+            return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
+
+        if not self._match(TokenType.COMMA):
+            self.raise_error("Expected FROM or comma after EXTRACT", self._prev)

         return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
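_parse_extract now accepts a comma as well as FROM, which is what lets Snowflake's DATE_PART delegate to it; expected output per the new Snowflake test:

    import sqlglot

    print(sqlglot.transpile(
        "SELECT DATE_PART(month FROM a::DATETIME)",
        read="snowflake",
        write="snowflake",
    )[0])
    # SELECT EXTRACT(month FROM CAST(a AS DATETIME))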

sqlglot/tokens.py

@@ -41,6 +41,7 @@ class TokenType(AutoName):
     LR_ARROW = auto()
     ANNOTATION = auto()
     DOLLAR = auto()
+    PARAMETER = auto()
     SPACE = auto()
     BREAK = auto()
@@ -75,6 +76,7 @@ class TokenType(AutoName):
     JSON = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
+    TIMESTAMPLTZ = auto()
     DATETIME = auto()
     DATE = auto()
     UUID = auto()
@@ -86,6 +88,13 @@ class TokenType(AutoName):
     SERIAL = auto()
     SMALLSERIAL = auto()
     BIGSERIAL = auto()
+    XML = auto()
+    UNIQUEIDENTIFIER = auto()
+    MONEY = auto()
+    SMALLMONEY = auto()
+    ROWVERSION = auto()
+    IMAGE = auto()
+    SQL_VARIANT = auto()

     # keywords
     ADD_FILE = auto()
@@ -247,6 +256,7 @@ class TokenType(AutoName):
     WINDOW = auto()
     WITH = auto()
     WITH_TIME_ZONE = auto()
+    WITH_LOCAL_TIME_ZONE = auto()
     WITHIN_GROUP = auto()
     WITHOUT_TIME_ZONE = auto()
     UNIQUE = auto()
@@ -340,7 +350,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
         "#": TokenType.ANNOTATION,
-        "$": TokenType.DOLLAR,
+        "@": TokenType.PARAMETER,
         # used for breaking a var like x'y' but nothing else
         # the token type doesn't matter
         "'": TokenType.QUOTE,
@@ -520,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "WHERE": TokenType.WHERE,
         "WITH": TokenType.WITH,
         "WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
+        "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
         "WITHIN GROUP": TokenType.WITHIN_GROUP,
         "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
         "ARRAY": TokenType.ARRAY,
@@ -561,6 +572,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "BYTEA": TokenType.BINARY,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
+        "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
         "DATE": TokenType.DATE,
         "DATETIME": TokenType.DATETIME,
         "UNIQUE": TokenType.UNIQUE,

tests/dialects/test_dialect.py

@@ -228,6 +228,7 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
                 "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
+                "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
                 "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
             },
         )
@@ -237,6 +238,7 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
                 "hive": "CAST('2020-01-01' AS TIMESTAMP)",
                 "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
+                "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
                 "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
             },
         )
@@ -246,6 +248,7 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME(x, '%y')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
                 "presto": "DATE_PARSE(x, '%y')",
+                "redshift": "TO_TIMESTAMP(x, 'YY')",
                 "spark": "TO_TIMESTAMP(x, 'yy')",
             },
         )
@@ -287,6 +290,7 @@ class TestDialect(Validator):
                 "duckdb": "STRFTIME(x, '%Y-%m-%d')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
                 "presto": "DATE_FORMAT(x, '%Y-%m-%d')",
+                "redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
             },
         )
         self.validate_all(
@@ -295,6 +299,7 @@ class TestDialect(Validator):
                 "duckdb": "CAST(x AS TEXT)",
                 "hive": "CAST(x AS STRING)",
                 "presto": "CAST(x AS VARCHAR)",
+                "redshift": "CAST(x AS TEXT)",
             },
         )
         self.validate_all(

tests/dialects/test_duckdb.py

@@ -66,6 +66,9 @@ class TestDuckDB(Validator):
     def test_duckdb(self):
         self.validate_all(
             "LIST_VALUE(0, 1, 2)",
+            read={
+                "spark": "ARRAY(0, 1, 2)",
+            },
             write={
                 "bigquery": "[0, 1, 2]",
                 "duckdb": "LIST_VALUE(0, 1, 2)",

tests/dialects/test_hive.py

@@ -131,7 +131,7 @@ class TestHive(Validator):
             write={
                 "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
                 "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
-                "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
+                "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
             },
         )
         self.validate_all(
self.validate_all( self.validate_all(

tests/dialects/test_oracle.py (new file)

@@ -0,0 +1,6 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestOracle(Validator):
+    def test_oracle(self):
+        self.validate_identity("SELECT * FROM V$SESSION")

tests/dialects/test_presto.py

@@ -173,7 +173,7 @@ class TestPresto(Validator):
             write={
                 "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
                 "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
-                "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
+                "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
             },
         )
         self.validate_all(
@@ -181,7 +181,7 @@ class TestPresto(Validator):
             write={
                 "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
                 "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
-                "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
+                "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
             },
         )
         self.validate_all(
self.validate_all( self.validate_all(

tests/dialects/test_snowflake.py

@@ -175,3 +175,48 @@ class TestSnowflake(Validator):
                 "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
             },
         )
+
+    def test_timestamps(self):
+        self.validate_all(
+            "SELECT CAST(a AS TIMESTAMP)",
+            write={
+                "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)",
+            },
+        )
+        self.validate_all(
+            "SELECT a::TIMESTAMP_LTZ(9)",
+            write={
+                "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ(9))",
+            },
+        )
+        self.validate_all(
+            "SELECT a::TIMESTAMPLTZ",
+            write={
+                "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
+            },
+        )
+        self.validate_all(
+            "SELECT a::TIMESTAMP WITH LOCAL TIME ZONE",
+            write={
+                "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
+            },
+        )
+        self.validate_identity("SELECT EXTRACT(month FROM a)")
+        self.validate_all(
+            "SELECT EXTRACT('month', a)",
+            write={
+                "snowflake": "SELECT EXTRACT('month' FROM a)",
+            },
+        )
+        self.validate_all(
+            "SELECT DATE_PART('month', a)",
+            write={
+                "snowflake": "SELECT EXTRACT('month' FROM a)",
+            },
+        )
+        self.validate_all(
+            "SELECT DATE_PART(month FROM a::DATETIME)",
+            write={
+                "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
+            },
+        )

tests/dialects/test_spark.py

@@ -44,15 +44,7 @@ class TestSpark(Validator):
             write={
                 "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
                 "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
-                "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
-            },
-        )
-        self.validate_all(
-            "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
-            write={
-                "presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1",
-                "hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
-                "spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
+                "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
             },
         )
         self.validate_all(
@@ -86,7 +78,7 @@ COMMENT 'Test comment: blah'
 PARTITIONED BY (
     date STRING
 )
-STORED AS ICEBERG
+USING ICEBERG
 TBLPROPERTIES (
     'x' = '1'
 )""",

tests/dialects/test_tsql.py (new file)

@@ -0,0 +1,26 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestTSQL(Validator):
+    dialect = "tsql"
+
+    def test_tsql(self):
+        self.validate_identity('SELECT "x"."y" FROM foo')
+
+        self.validate_all(
+            "SELECT CAST([a].[b] AS SMALLINT) FROM foo",
+            write={
+                "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo',
+                "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
+            },
+        )
+
+    def test_types(self):
+        self.validate_identity("CAST(x AS XML)")
+        self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
+        self.validate_identity("CAST(x AS MONEY)")
+        self.validate_identity("CAST(x AS SMALLMONEY)")
+        self.validate_identity("CAST(x AS ROWVERSION)")
+        self.validate_identity("CAST(x AS IMAGE)")
+        self.validate_identity("CAST(x AS SQL_VARIANT)")
+        self.validate_identity("CAST(x AS BIT)")

tests/test_expressions.py

@@ -224,9 +224,6 @@ class TestExpressions(unittest.TestCase):
         self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
         self.assertIs(actual_expression_2, expression)

-        with self.assertRaises(ValueError):
-            parse_one("a").transform(lambda n: None)
-
     def test_transform_no_infinite_recursion(self):
         expression = parse_one("a")
@@ -247,6 +244,35 @@ class TestExpressions(unittest.TestCase):
         self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x")

+    def test_transform_node_removal(self):
+        expression = parse_one("SELECT a, b FROM x")
+
+        def remove_column_b(node):
+            if isinstance(node, exp.Column) and node.name == "b":
+                return None
+            return node
+
+        self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x")
+        self.assertEqual(expression.transform(lambda _: None), None)
+
+        expression = parse_one("CAST(x AS FLOAT)")
+
+        def remove_non_list_arg(node):
+            if isinstance(node, exp.DataType):
+                return None
+            return node
+
+        self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )")
+
+        expression = parse_one("SELECT a, b FROM x")
+
+        def remove_all_columns(node):
+            if isinstance(node, exp.Column):
+                return None
+            return node
+
+        self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x")
+
     def test_replace(self):
         expression = parse_one("SELECT a, b FROM x")
         expression.find(exp.Column).replace(parse_one("c"))

tests/test_parser.py

@@ -114,6 +114,9 @@ class TestParser(unittest.TestCase):
         with self.assertRaises(ParseError):
             parse_one("SELECT FROM x ORDER BY")

+    def test_parameter(self):
+        self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
+
     def test_annotations(self):
         expression = parse_one(
             """