Adding upstream version 6.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
71430b22d0
commit
bf84d96bab
22 changed files with 312 additions and 45 deletions
17
CHANGELOG.md
17
CHANGELOG.md
|
@ -1,6 +1,23 @@
|
||||||
Changelog
|
Changelog
|
||||||
=========
|
=========
|
||||||
|
|
||||||
|
v6.2.0
|
||||||
|
------
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
- New: TSQL support
|
||||||
|
|
||||||
|
- Breaking: Removed $ from tokenizer, added @ placeholders
|
||||||
|
|
||||||
|
- Improvement: Nodes can now be removed in transform and replace [8cd81c3](https://github.com/tobymao/sqlglot/commit/8cd81c36561463b9849a8e0c2d70248c5b1feb62)
|
||||||
|
|
||||||
|
- Improvement: Snowflake timestamp support
|
||||||
|
|
||||||
|
- Improvement: Property conversion for CTAS Builder
|
||||||
|
|
||||||
|
- Improvement: Tokenizers are now unique per dialect instance
|
||||||
|
|
||||||
v6.1.0
|
v6.1.0
|
||||||
------
|
------
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from sqlglot.generator import Generator
|
||||||
from sqlglot.parser import Parser
|
from sqlglot.parser import Parser
|
||||||
from sqlglot.tokens import Tokenizer, TokenType
|
from sqlglot.tokens import Tokenizer, TokenType
|
||||||
|
|
||||||
__version__ = "6.1.1"
|
__version__ = "6.2.0"
|
||||||
|
|
||||||
pretty = False
|
pretty = False
|
||||||
|
|
||||||
|
|
|
@ -14,3 +14,4 @@ from sqlglot.dialects.sqlite import SQLite
|
||||||
from sqlglot.dialects.starrocks import StarRocks
|
from sqlglot.dialects.starrocks import StarRocks
|
||||||
from sqlglot.dialects.tableau import Tableau
|
from sqlglot.dialects.tableau import Tableau
|
||||||
from sqlglot.dialects.trino import Trino
|
from sqlglot.dialects.trino import Trino
|
||||||
|
from sqlglot.dialects.tsql import TSQL
|
||||||
|
|
|
@ -27,6 +27,7 @@ class Dialects(str, Enum):
|
||||||
STARROCKS = "starrocks"
|
STARROCKS = "starrocks"
|
||||||
TABLEAU = "tableau"
|
TABLEAU = "tableau"
|
||||||
TRINO = "trino"
|
TRINO = "trino"
|
||||||
|
TSQL = "tsql"
|
||||||
|
|
||||||
|
|
||||||
class _Dialect(type):
|
class _Dialect(type):
|
||||||
|
@ -53,7 +54,6 @@ class _Dialect(type):
|
||||||
klass.parser_class = getattr(klass, "Parser", Parser)
|
klass.parser_class = getattr(klass, "Parser", Parser)
|
||||||
klass.generator_class = getattr(klass, "Generator", Generator)
|
klass.generator_class = getattr(klass, "Generator", Generator)
|
||||||
|
|
||||||
klass.tokenizer = klass.tokenizer_class()
|
|
||||||
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
|
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
|
||||||
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
|
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
|
||||||
|
|
||||||
|
@ -95,7 +95,6 @@ class Dialect(metaclass=_Dialect):
|
||||||
tokenizer_class = None
|
tokenizer_class = None
|
||||||
parser_class = None
|
parser_class = None
|
||||||
generator_class = None
|
generator_class = None
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_or_raise(cls, dialect):
|
def get_or_raise(cls, dialect):
|
||||||
|
@ -138,6 +137,12 @@ class Dialect(metaclass=_Dialect):
|
||||||
def transpile(self, code, **opts):
|
def transpile(self, code, **opts):
|
||||||
return self.generate(self.parse(code), **opts)
|
return self.generate(self.parse(code), **opts)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer(self):
|
||||||
|
if not hasattr(self, "_tokenizer"):
|
||||||
|
self._tokenizer = self.tokenizer_class()
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
def parser(self, **opts):
|
def parser(self, **opts):
|
||||||
return self.parser_class(
|
return self.parser_class(
|
||||||
**{
|
**{
|
||||||
|
@ -170,7 +175,15 @@ class Dialect(metaclass=_Dialect):
|
||||||
|
|
||||||
|
|
||||||
def rename_func(name):
|
def rename_func(name):
|
||||||
return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
|
def _rename(self, expression):
|
||||||
|
args = (
|
||||||
|
self.expressions(expression, flat=True)
|
||||||
|
if isinstance(expression, exp.Func) and expression.is_var_len_args
|
||||||
|
else csv(*[self.sql(e) for e in expression.args.values()])
|
||||||
|
)
|
||||||
|
return f"{name}({args})"
|
||||||
|
|
||||||
|
return _rename
|
||||||
|
|
||||||
|
|
||||||
def approx_count_distinct_sql(self, expression):
|
def approx_count_distinct_sql(self, expression):
|
||||||
|
|
|
@ -108,7 +108,7 @@ class DuckDB(Dialect):
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**Generator.TRANSFORMS,
|
**Generator.TRANSFORMS,
|
||||||
exp.ApproxDistinct: approx_count_distinct_sql,
|
exp.ApproxDistinct: approx_count_distinct_sql,
|
||||||
exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
|
exp.Array: rename_func("LIST_VALUE"),
|
||||||
exp.ArraySize: rename_func("ARRAY_LENGTH"),
|
exp.ArraySize: rename_func("ARRAY_LENGTH"),
|
||||||
exp.ArraySort: _array_sort_sql,
|
exp.ArraySort: _array_sort_sql,
|
||||||
exp.ArraySum: rename_func("LIST_SUM"),
|
exp.ArraySum: rename_func("LIST_SUM"),
|
||||||
|
|
|
@ -106,6 +106,11 @@ class Snowflake(Dialect):
|
||||||
"TO_TIMESTAMP": _snowflake_to_timestamp,
|
"TO_TIMESTAMP": _snowflake_to_timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FUNCTION_PARSERS = {
|
||||||
|
**Parser.FUNCTION_PARSERS,
|
||||||
|
"DATE_PART": lambda self: self._parse_extract(),
|
||||||
|
}
|
||||||
|
|
||||||
COLUMN_OPERATORS = {
|
COLUMN_OPERATORS = {
|
||||||
**Parser.COLUMN_OPERATORS,
|
**Parser.COLUMN_OPERATORS,
|
||||||
TokenType.COLON: lambda self, this, path: self.expression(
|
TokenType.COLON: lambda self, this, path: self.expression(
|
||||||
|
@ -118,10 +123,20 @@ class Snowflake(Dialect):
|
||||||
class Tokenizer(Tokenizer):
|
class Tokenizer(Tokenizer):
|
||||||
QUOTES = ["'", "$$"]
|
QUOTES = ["'", "$$"]
|
||||||
ESCAPE = "\\"
|
ESCAPE = "\\"
|
||||||
|
|
||||||
|
SINGLE_TOKENS = {
|
||||||
|
**Tokenizer.SINGLE_TOKENS,
|
||||||
|
"$": TokenType.DOLLAR, # needed to break for quotes
|
||||||
|
}
|
||||||
|
|
||||||
KEYWORDS = {
|
KEYWORDS = {
|
||||||
**Tokenizer.KEYWORDS,
|
**Tokenizer.KEYWORDS,
|
||||||
"QUALIFY": TokenType.QUALIFY,
|
"QUALIFY": TokenType.QUALIFY,
|
||||||
"DOUBLE PRECISION": TokenType.DOUBLE,
|
"DOUBLE PRECISION": TokenType.DOUBLE,
|
||||||
|
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
|
||||||
|
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
|
||||||
|
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
|
||||||
|
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
|
||||||
}
|
}
|
||||||
|
|
||||||
class Generator(Generator):
|
class Generator(Generator):
|
||||||
|
@ -132,6 +147,11 @@ class Snowflake(Dialect):
|
||||||
exp.UnixToTime: _unix_to_time,
|
exp.UnixToTime: _unix_to_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPE_MAPPING = {
|
||||||
|
**Generator.TYPE_MAPPING,
|
||||||
|
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
|
||||||
|
}
|
||||||
|
|
||||||
def except_op(self, expression):
|
def except_op(self, expression):
|
||||||
if not expression.args.get("distinct", False):
|
if not expression.args.get("distinct", False):
|
||||||
self.unsupported("EXCEPT with All is not supported in Snowflake")
|
self.unsupported("EXCEPT with All is not supported in Snowflake")
|
||||||
|
|
|
@ -82,6 +82,7 @@ class Spark(Hive):
|
||||||
|
|
||||||
TRANSFORMS = {
|
TRANSFORMS = {
|
||||||
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
|
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
|
||||||
|
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
|
||||||
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
|
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
|
||||||
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
|
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
|
||||||
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
|
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
|
||||||
|
|
38
sqlglot/dialects/tsql.py
Normal file
38
sqlglot/dialects/tsql.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
from sqlglot import exp
|
||||||
|
from sqlglot.dialects.dialect import Dialect
|
||||||
|
from sqlglot.generator import Generator
|
||||||
|
from sqlglot.tokens import Tokenizer, TokenType
|
||||||
|
|
||||||
|
|
||||||
|
class TSQL(Dialect):
|
||||||
|
null_ordering = "nulls_are_small"
|
||||||
|
time_format = "'yyyy-mm-dd hh:mm:ss'"
|
||||||
|
|
||||||
|
class Tokenizer(Tokenizer):
|
||||||
|
IDENTIFIERS = ['"', ("[", "]")]
|
||||||
|
|
||||||
|
KEYWORDS = {
|
||||||
|
**Tokenizer.KEYWORDS,
|
||||||
|
"BIT": TokenType.BOOLEAN,
|
||||||
|
"REAL": TokenType.FLOAT,
|
||||||
|
"NTEXT": TokenType.TEXT,
|
||||||
|
"SMALLDATETIME": TokenType.DATETIME,
|
||||||
|
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
|
||||||
|
"TIME": TokenType.TIMESTAMP,
|
||||||
|
"VARBINARY": TokenType.BINARY,
|
||||||
|
"IMAGE": TokenType.IMAGE,
|
||||||
|
"MONEY": TokenType.MONEY,
|
||||||
|
"SMALLMONEY": TokenType.SMALLMONEY,
|
||||||
|
"ROWVERSION": TokenType.ROWVERSION,
|
||||||
|
"SQL_VARIANT": TokenType.SQL_VARIANT,
|
||||||
|
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||||
|
"XML": TokenType.XML,
|
||||||
|
}
|
||||||
|
|
||||||
|
class Generator(Generator):
|
||||||
|
TYPE_MAPPING = {
|
||||||
|
**Generator.TYPE_MAPPING,
|
||||||
|
exp.DataType.Type.BOOLEAN: "BIT",
|
||||||
|
exp.DataType.Type.INT: "INTEGER",
|
||||||
|
exp.DataType.Type.DECIMAL: "NUMERIC",
|
||||||
|
}
|
|
@ -1,4 +1,5 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
import numbers
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
@ -6,7 +7,7 @@ from copy import deepcopy
|
||||||
from enum import auto
|
from enum import auto
|
||||||
|
|
||||||
from sqlglot.errors import ParseError
|
from sqlglot.errors import ParseError
|
||||||
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
|
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
|
||||||
|
|
||||||
|
|
||||||
class _Expression(type):
|
class _Expression(type):
|
||||||
|
@ -350,7 +351,8 @@ class Expression(metaclass=_Expression):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fun (function): a function which takes a node as an argument and returns a
|
fun (function): a function which takes a node as an argument and returns a
|
||||||
new transformed node or the same node without modifications.
|
new transformed node or the same node without modifications. If the function
|
||||||
|
returns None, then the corresponding node will be removed from the syntax tree.
|
||||||
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
|
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
|
||||||
modified in place.
|
modified in place.
|
||||||
|
|
||||||
|
@ -360,9 +362,7 @@ class Expression(metaclass=_Expression):
|
||||||
node = self.copy() if copy else self
|
node = self.copy() if copy else self
|
||||||
new_node = fun(node, *args, **kwargs)
|
new_node = fun(node, *args, **kwargs)
|
||||||
|
|
||||||
if new_node is None:
|
if new_node is None or not isinstance(new_node, Expression):
|
||||||
raise ValueError("A transformed node cannot be None")
|
|
||||||
if not isinstance(new_node, Expression):
|
|
||||||
return new_node
|
return new_node
|
||||||
if new_node is not node:
|
if new_node is not node:
|
||||||
new_node.parent = node.parent
|
new_node.parent = node.parent
|
||||||
|
@ -843,10 +843,6 @@ class Ordered(Expression):
|
||||||
arg_types = {"this": True, "desc": True, "nulls_first": True}
|
arg_types = {"this": True, "desc": True, "nulls_first": True}
|
||||||
|
|
||||||
|
|
||||||
class Properties(Expression):
|
|
||||||
arg_types = {"expressions": True}
|
|
||||||
|
|
||||||
|
|
||||||
class Property(Expression):
|
class Property(Expression):
|
||||||
arg_types = {"this": True, "value": True}
|
arg_types = {"this": True, "value": True}
|
||||||
|
|
||||||
|
@ -891,6 +887,42 @@ class AnonymousProperty(Property):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Properties(Expression):
|
||||||
|
arg_types = {"expressions": True}
|
||||||
|
|
||||||
|
PROPERTY_KEY_MAPPING = {
|
||||||
|
"AUTO_INCREMENT": AutoIncrementProperty,
|
||||||
|
"CHARACTER_SET": CharacterSetProperty,
|
||||||
|
"COLLATE": CollateProperty,
|
||||||
|
"COMMENT": SchemaCommentProperty,
|
||||||
|
"ENGINE": EngineProperty,
|
||||||
|
"FORMAT": FileFormatProperty,
|
||||||
|
"LOCATION": LocationProperty,
|
||||||
|
"PARTITIONED_BY": PartitionedByProperty,
|
||||||
|
"TABLE_FORMAT": TableFormatProperty,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, properties_dict):
|
||||||
|
expressions = []
|
||||||
|
for key, value in properties_dict.items():
|
||||||
|
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
|
||||||
|
expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
|
||||||
|
return cls(expressions=expressions)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_value(value):
|
||||||
|
if isinstance(value, Expression):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return Literal.string(value)
|
||||||
|
if isinstance(value, numbers.Number):
|
||||||
|
return Literal.number(value)
|
||||||
|
if isinstance(value, list):
|
||||||
|
return Tuple(expressions=[_convert_value(v) for v in value])
|
||||||
|
raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
class Qualify(Expression):
|
class Qualify(Expression):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression):
|
||||||
)
|
)
|
||||||
properties_expression = None
|
properties_expression = None
|
||||||
if properties:
|
if properties:
|
||||||
properties_str = " ".join(
|
properties_expression = Properties.from_dict(properties)
|
||||||
[f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
|
|
||||||
)
|
|
||||||
properties_expression = maybe_parse(
|
|
||||||
properties_str,
|
|
||||||
into=Properties,
|
|
||||||
dialect=dialect,
|
|
||||||
**opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Create(
|
return Create(
|
||||||
this=table_expression,
|
this=table_expression,
|
||||||
|
@ -1650,6 +1674,10 @@ class Star(Expression):
|
||||||
return "*"
|
return "*"
|
||||||
|
|
||||||
|
|
||||||
|
class Parameter(Expression):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Placeholder(Expression):
|
class Placeholder(Expression):
|
||||||
arg_types = {}
|
arg_types = {}
|
||||||
|
|
||||||
|
@ -1688,6 +1716,7 @@ class DataType(Expression):
|
||||||
INTERVAL = auto()
|
INTERVAL = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
|
TIMESTAMPLTZ = auto()
|
||||||
DATE = auto()
|
DATE = auto()
|
||||||
DATETIME = auto()
|
DATETIME = auto()
|
||||||
ARRAY = auto()
|
ARRAY = auto()
|
||||||
|
@ -1702,6 +1731,13 @@ class DataType(Expression):
|
||||||
SERIAL = auto()
|
SERIAL = auto()
|
||||||
SMALLSERIAL = auto()
|
SMALLSERIAL = auto()
|
||||||
BIGSERIAL = auto()
|
BIGSERIAL = auto()
|
||||||
|
XML = auto()
|
||||||
|
UNIQUEIDENTIFIER = auto()
|
||||||
|
MONEY = auto()
|
||||||
|
SMALLMONEY = auto()
|
||||||
|
ROWVERSION = auto()
|
||||||
|
IMAGE = auto()
|
||||||
|
SQL_VARIANT = auto()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, dtype, **kwargs):
|
def build(cls, dtype, **kwargs):
|
||||||
|
@ -2976,7 +3012,7 @@ def replace_children(expression, fun):
|
||||||
else:
|
else:
|
||||||
new_child_nodes.append(cn)
|
new_child_nodes.append(cn)
|
||||||
|
|
||||||
expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
|
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
|
||||||
|
|
||||||
|
|
||||||
def column_table_names(expression):
|
def column_table_names(expression):
|
||||||
|
|
|
@ -748,6 +748,9 @@ class Generator:
|
||||||
def structkwarg_sql(self, expression):
|
def structkwarg_sql(self, expression):
|
||||||
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
|
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
|
||||||
|
|
||||||
|
def parameter_sql(self, expression):
|
||||||
|
return f"@{self.sql(expression, 'this')}"
|
||||||
|
|
||||||
def placeholder_sql(self, *_):
|
def placeholder_sql(self, *_):
|
||||||
return "?"
|
return "?"
|
||||||
|
|
||||||
|
@ -903,7 +906,7 @@ class Generator:
|
||||||
return f"UNIQUE ({columns})"
|
return f"UNIQUE ({columns})"
|
||||||
|
|
||||||
def if_sql(self, expression):
|
def if_sql(self, expression):
|
||||||
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
|
return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
|
||||||
|
|
||||||
def in_sql(self, expression):
|
def in_sql(self, expression):
|
||||||
query = expression.args.get("query")
|
query = expression.args.get("query")
|
||||||
|
|
|
@ -81,6 +81,7 @@ class Parser:
|
||||||
TokenType.INTERVAL,
|
TokenType.INTERVAL,
|
||||||
TokenType.TIMESTAMP,
|
TokenType.TIMESTAMP,
|
||||||
TokenType.TIMESTAMPTZ,
|
TokenType.TIMESTAMPTZ,
|
||||||
|
TokenType.TIMESTAMPLTZ,
|
||||||
TokenType.DATETIME,
|
TokenType.DATETIME,
|
||||||
TokenType.DATE,
|
TokenType.DATE,
|
||||||
TokenType.DECIMAL,
|
TokenType.DECIMAL,
|
||||||
|
@ -92,6 +93,13 @@ class Parser:
|
||||||
TokenType.SERIAL,
|
TokenType.SERIAL,
|
||||||
TokenType.SMALLSERIAL,
|
TokenType.SMALLSERIAL,
|
||||||
TokenType.BIGSERIAL,
|
TokenType.BIGSERIAL,
|
||||||
|
TokenType.XML,
|
||||||
|
TokenType.UNIQUEIDENTIFIER,
|
||||||
|
TokenType.MONEY,
|
||||||
|
TokenType.SMALLMONEY,
|
||||||
|
TokenType.ROWVERSION,
|
||||||
|
TokenType.IMAGE,
|
||||||
|
TokenType.SQL_VARIANT,
|
||||||
*NESTED_TYPE_TOKENS,
|
*NESTED_TYPE_TOKENS,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,6 +241,7 @@ class Parser:
|
||||||
TIMESTAMPS = {
|
TIMESTAMPS = {
|
||||||
TokenType.TIMESTAMP,
|
TokenType.TIMESTAMP,
|
||||||
TokenType.TIMESTAMPTZ,
|
TokenType.TIMESTAMPTZ,
|
||||||
|
TokenType.TIMESTAMPLTZ,
|
||||||
}
|
}
|
||||||
|
|
||||||
SET_OPERATIONS = {
|
SET_OPERATIONS = {
|
||||||
|
@ -315,6 +324,7 @@ class Parser:
|
||||||
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
|
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
|
||||||
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
|
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
|
||||||
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
|
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
|
||||||
|
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
|
||||||
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
|
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
|
||||||
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
|
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
|
||||||
TokenType.INTRODUCER: lambda self, token: self.expression(
|
TokenType.INTRODUCER: lambda self, token: self.expression(
|
||||||
|
@ -1497,12 +1507,19 @@ class Parser:
|
||||||
|
|
||||||
if type_token in self.TIMESTAMPS:
|
if type_token in self.TIMESTAMPS:
|
||||||
tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
|
tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
|
||||||
self._match(TokenType.WITHOUT_TIME_ZONE)
|
|
||||||
if tz:
|
if tz:
|
||||||
return exp.DataType(
|
return exp.DataType(
|
||||||
this=exp.DataType.Type.TIMESTAMPTZ,
|
this=exp.DataType.Type.TIMESTAMPTZ,
|
||||||
expressions=expressions,
|
expressions=expressions,
|
||||||
)
|
)
|
||||||
|
ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
|
||||||
|
if ltz:
|
||||||
|
return exp.DataType(
|
||||||
|
this=exp.DataType.Type.TIMESTAMPLTZ,
|
||||||
|
expressions=expressions,
|
||||||
|
)
|
||||||
|
self._match(TokenType.WITHOUT_TIME_ZONE)
|
||||||
|
|
||||||
return exp.DataType(
|
return exp.DataType(
|
||||||
this=exp.DataType.Type.TIMESTAMP,
|
this=exp.DataType.Type.TIMESTAMP,
|
||||||
expressions=expressions,
|
expressions=expressions,
|
||||||
|
@ -1845,8 +1862,11 @@ class Parser:
|
||||||
def _parse_extract(self):
|
def _parse_extract(self):
|
||||||
this = self._parse_var() or self._parse_type()
|
this = self._parse_var() or self._parse_type()
|
||||||
|
|
||||||
if not self._match(TokenType.FROM):
|
if self._match(TokenType.FROM):
|
||||||
self.raise_error("Expected FROM after EXTRACT", self._prev)
|
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
|
||||||
|
|
||||||
|
if not self._match(TokenType.COMMA):
|
||||||
|
self.raise_error("Expected FROM or comma after EXTRACT", self._prev)
|
||||||
|
|
||||||
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
|
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,7 @@ class TokenType(AutoName):
|
||||||
LR_ARROW = auto()
|
LR_ARROW = auto()
|
||||||
ANNOTATION = auto()
|
ANNOTATION = auto()
|
||||||
DOLLAR = auto()
|
DOLLAR = auto()
|
||||||
|
PARAMETER = auto()
|
||||||
|
|
||||||
SPACE = auto()
|
SPACE = auto()
|
||||||
BREAK = auto()
|
BREAK = auto()
|
||||||
|
@ -75,6 +76,7 @@ class TokenType(AutoName):
|
||||||
JSON = auto()
|
JSON = auto()
|
||||||
TIMESTAMP = auto()
|
TIMESTAMP = auto()
|
||||||
TIMESTAMPTZ = auto()
|
TIMESTAMPTZ = auto()
|
||||||
|
TIMESTAMPLTZ = auto()
|
||||||
DATETIME = auto()
|
DATETIME = auto()
|
||||||
DATE = auto()
|
DATE = auto()
|
||||||
UUID = auto()
|
UUID = auto()
|
||||||
|
@ -86,6 +88,13 @@ class TokenType(AutoName):
|
||||||
SERIAL = auto()
|
SERIAL = auto()
|
||||||
SMALLSERIAL = auto()
|
SMALLSERIAL = auto()
|
||||||
BIGSERIAL = auto()
|
BIGSERIAL = auto()
|
||||||
|
XML = auto()
|
||||||
|
UNIQUEIDENTIFIER = auto()
|
||||||
|
MONEY = auto()
|
||||||
|
SMALLMONEY = auto()
|
||||||
|
ROWVERSION = auto()
|
||||||
|
IMAGE = auto()
|
||||||
|
SQL_VARIANT = auto()
|
||||||
|
|
||||||
# keywords
|
# keywords
|
||||||
ADD_FILE = auto()
|
ADD_FILE = auto()
|
||||||
|
@ -247,6 +256,7 @@ class TokenType(AutoName):
|
||||||
WINDOW = auto()
|
WINDOW = auto()
|
||||||
WITH = auto()
|
WITH = auto()
|
||||||
WITH_TIME_ZONE = auto()
|
WITH_TIME_ZONE = auto()
|
||||||
|
WITH_LOCAL_TIME_ZONE = auto()
|
||||||
WITHIN_GROUP = auto()
|
WITHIN_GROUP = auto()
|
||||||
WITHOUT_TIME_ZONE = auto()
|
WITHOUT_TIME_ZONE = auto()
|
||||||
UNIQUE = auto()
|
UNIQUE = auto()
|
||||||
|
@ -340,7 +350,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"~": TokenType.TILDA,
|
"~": TokenType.TILDA,
|
||||||
"?": TokenType.PLACEHOLDER,
|
"?": TokenType.PLACEHOLDER,
|
||||||
"#": TokenType.ANNOTATION,
|
"#": TokenType.ANNOTATION,
|
||||||
"$": TokenType.DOLLAR,
|
"@": TokenType.PARAMETER,
|
||||||
# used for breaking a var like x'y' but nothing else
|
# used for breaking a var like x'y' but nothing else
|
||||||
# the token type doesn't matter
|
# the token type doesn't matter
|
||||||
"'": TokenType.QUOTE,
|
"'": TokenType.QUOTE,
|
||||||
|
@ -520,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"WHERE": TokenType.WHERE,
|
"WHERE": TokenType.WHERE,
|
||||||
"WITH": TokenType.WITH,
|
"WITH": TokenType.WITH,
|
||||||
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
|
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
|
||||||
|
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
|
||||||
"WITHIN GROUP": TokenType.WITHIN_GROUP,
|
"WITHIN GROUP": TokenType.WITHIN_GROUP,
|
||||||
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
|
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
|
||||||
"ARRAY": TokenType.ARRAY,
|
"ARRAY": TokenType.ARRAY,
|
||||||
|
@ -561,6 +572,7 @@ class Tokenizer(metaclass=_Tokenizer):
|
||||||
"BYTEA": TokenType.BINARY,
|
"BYTEA": TokenType.BINARY,
|
||||||
"TIMESTAMP": TokenType.TIMESTAMP,
|
"TIMESTAMP": TokenType.TIMESTAMP,
|
||||||
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
|
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
|
||||||
|
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
|
||||||
"DATE": TokenType.DATE,
|
"DATE": TokenType.DATE,
|
||||||
"DATETIME": TokenType.DATETIME,
|
"DATETIME": TokenType.DATETIME,
|
||||||
"UNIQUE": TokenType.UNIQUE,
|
"UNIQUE": TokenType.UNIQUE,
|
||||||
|
|
|
@ -228,6 +228,7 @@ class TestDialect(Validator):
|
||||||
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
|
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
|
||||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
|
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
|
||||||
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
|
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
|
||||||
|
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
|
||||||
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
|
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -237,6 +238,7 @@ class TestDialect(Validator):
|
||||||
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
|
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
|
||||||
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
|
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
|
||||||
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
|
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
|
||||||
|
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
|
||||||
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
|
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -246,6 +248,7 @@ class TestDialect(Validator):
|
||||||
"duckdb": "STRPTIME(x, '%y')",
|
"duckdb": "STRPTIME(x, '%y')",
|
||||||
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
|
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
|
||||||
"presto": "DATE_PARSE(x, '%y')",
|
"presto": "DATE_PARSE(x, '%y')",
|
||||||
|
"redshift": "TO_TIMESTAMP(x, 'YY')",
|
||||||
"spark": "TO_TIMESTAMP(x, 'yy')",
|
"spark": "TO_TIMESTAMP(x, 'yy')",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -287,6 +290,7 @@ class TestDialect(Validator):
|
||||||
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
|
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
|
||||||
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
|
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
|
||||||
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
|
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
|
||||||
|
"redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -295,6 +299,7 @@ class TestDialect(Validator):
|
||||||
"duckdb": "CAST(x AS TEXT)",
|
"duckdb": "CAST(x AS TEXT)",
|
||||||
"hive": "CAST(x AS STRING)",
|
"hive": "CAST(x AS STRING)",
|
||||||
"presto": "CAST(x AS VARCHAR)",
|
"presto": "CAST(x AS VARCHAR)",
|
||||||
|
"redshift": "CAST(x AS TEXT)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -66,6 +66,9 @@ class TestDuckDB(Validator):
|
||||||
def test_duckdb(self):
|
def test_duckdb(self):
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
"LIST_VALUE(0, 1, 2)",
|
"LIST_VALUE(0, 1, 2)",
|
||||||
|
read={
|
||||||
|
"spark": "ARRAY(0, 1, 2)",
|
||||||
|
},
|
||||||
write={
|
write={
|
||||||
"bigquery": "[0, 1, 2]",
|
"bigquery": "[0, 1, 2]",
|
||||||
"duckdb": "LIST_VALUE(0, 1, 2)",
|
"duckdb": "LIST_VALUE(0, 1, 2)",
|
||||||
|
|
|
@ -131,7 +131,7 @@ class TestHive(Validator):
|
||||||
write={
|
write={
|
||||||
"presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
|
"presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
|
||||||
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
|
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
|
||||||
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
|
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
6
tests/dialects/test_oracle.py
Normal file
6
tests/dialects/test_oracle.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from tests.dialects.test_dialect import Validator
|
||||||
|
|
||||||
|
|
||||||
|
class TestOracle(Validator):
|
||||||
|
def test_oracle(self):
|
||||||
|
self.validate_identity("SELECT * FROM V$SESSION")
|
|
@ -173,7 +173,7 @@ class TestPresto(Validator):
|
||||||
write={
|
write={
|
||||||
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
|
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
|
||||||
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
||||||
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -181,7 +181,7 @@ class TestPresto(Validator):
|
||||||
write={
|
write={
|
||||||
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
|
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
|
||||||
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
|
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
|
||||||
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
|
"spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
|
|
@ -175,3 +175,48 @@ class TestSnowflake(Validator):
|
||||||
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
|
"snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_timestamps(self):
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST(a AS TIMESTAMP)",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT a::TIMESTAMP_LTZ(9)",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ(9))",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT a::TIMESTAMPLTZ",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT a::TIMESTAMP WITH LOCAL TIME ZONE",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_identity("SELECT EXTRACT(month FROM a)")
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT EXTRACT('month', a)",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT EXTRACT('month' FROM a)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT DATE_PART('month', a)",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT EXTRACT('month' FROM a)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT DATE_PART(month FROM a::DATETIME)",
|
||||||
|
write={
|
||||||
|
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -44,15 +44,7 @@ class TestSpark(Validator):
|
||||||
write={
|
write={
|
||||||
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
|
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
|
||||||
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
||||||
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
|
"spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
|
||||||
},
|
|
||||||
)
|
|
||||||
self.validate_all(
|
|
||||||
"CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
|
|
||||||
write={
|
|
||||||
"presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1",
|
|
||||||
"hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
|
|
||||||
"spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.validate_all(
|
self.validate_all(
|
||||||
|
@ -86,7 +78,7 @@ COMMENT 'Test comment: blah'
|
||||||
PARTITIONED BY (
|
PARTITIONED BY (
|
||||||
date STRING
|
date STRING
|
||||||
)
|
)
|
||||||
STORED AS ICEBERG
|
USING ICEBERG
|
||||||
TBLPROPERTIES (
|
TBLPROPERTIES (
|
||||||
'x' = '1'
|
'x' = '1'
|
||||||
)""",
|
)""",
|
||||||
|
|
26
tests/dialects/test_tsql.py
Normal file
26
tests/dialects/test_tsql.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from tests.dialects.test_dialect import Validator
|
||||||
|
|
||||||
|
|
||||||
|
class TestTSQL(Validator):
|
||||||
|
dialect = "tsql"
|
||||||
|
|
||||||
|
def test_tsql(self):
|
||||||
|
self.validate_identity('SELECT "x"."y" FROM foo')
|
||||||
|
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST([a].[b] AS SMALLINT) FROM foo",
|
||||||
|
write={
|
||||||
|
"tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo',
|
||||||
|
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_types(self):
|
||||||
|
self.validate_identity("CAST(x AS XML)")
|
||||||
|
self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)")
|
||||||
|
self.validate_identity("CAST(x AS MONEY)")
|
||||||
|
self.validate_identity("CAST(x AS SMALLMONEY)")
|
||||||
|
self.validate_identity("CAST(x AS ROWVERSION)")
|
||||||
|
self.validate_identity("CAST(x AS IMAGE)")
|
||||||
|
self.validate_identity("CAST(x AS SQL_VARIANT)")
|
||||||
|
self.validate_identity("CAST(x AS BIT)")
|
|
@ -224,9 +224,6 @@ class TestExpressions(unittest.TestCase):
|
||||||
self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
|
self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)")
|
||||||
self.assertIs(actual_expression_2, expression)
|
self.assertIs(actual_expression_2, expression)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
parse_one("a").transform(lambda n: None)
|
|
||||||
|
|
||||||
def test_transform_no_infinite_recursion(self):
|
def test_transform_no_infinite_recursion(self):
|
||||||
expression = parse_one("a")
|
expression = parse_one("a")
|
||||||
|
|
||||||
|
@ -247,6 +244,35 @@ class TestExpressions(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x")
|
self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x")
|
||||||
|
|
||||||
|
def test_transform_node_removal(self):
|
||||||
|
expression = parse_one("SELECT a, b FROM x")
|
||||||
|
|
||||||
|
def remove_column_b(node):
|
||||||
|
if isinstance(node, exp.Column) and node.name == "b":
|
||||||
|
return None
|
||||||
|
return node
|
||||||
|
|
||||||
|
self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x")
|
||||||
|
self.assertEqual(expression.transform(lambda _: None), None)
|
||||||
|
|
||||||
|
expression = parse_one("CAST(x AS FLOAT)")
|
||||||
|
|
||||||
|
def remove_non_list_arg(node):
|
||||||
|
if isinstance(node, exp.DataType):
|
||||||
|
return None
|
||||||
|
return node
|
||||||
|
|
||||||
|
self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )")
|
||||||
|
|
||||||
|
expression = parse_one("SELECT a, b FROM x")
|
||||||
|
|
||||||
|
def remove_all_columns(node):
|
||||||
|
if isinstance(node, exp.Column):
|
||||||
|
return None
|
||||||
|
return node
|
||||||
|
|
||||||
|
self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x")
|
||||||
|
|
||||||
def test_replace(self):
|
def test_replace(self):
|
||||||
expression = parse_one("SELECT a, b FROM x")
|
expression = parse_one("SELECT a, b FROM x")
|
||||||
expression.find(exp.Column).replace(parse_one("c"))
|
expression.find(exp.Column).replace(parse_one("c"))
|
||||||
|
|
|
@ -114,6 +114,9 @@ class TestParser(unittest.TestCase):
|
||||||
with self.assertRaises(ParseError):
|
with self.assertRaises(ParseError):
|
||||||
parse_one("SELECT FROM x ORDER BY")
|
parse_one("SELECT FROM x ORDER BY")
|
||||||
|
|
||||||
|
def test_parameter(self):
|
||||||
|
self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")
|
||||||
|
|
||||||
def test_annotations(self):
|
def test_annotations(self):
|
||||||
expression = parse_one(
|
expression = parse_one(
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue