1
0
Fork 0

Merging upstream version 6.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:31:47 +01:00
parent 0822fbed3a
commit 9bc11b290e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
22 changed files with 312 additions and 45 deletions

View file

@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.1.1"
__version__ = "6.2.0"
pretty = False

View file

@ -14,3 +14,4 @@ from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL

View file

@ -27,6 +27,7 @@ class Dialects(str, Enum):
STARROCKS = "starrocks"
TABLEAU = "tableau"
TRINO = "trino"
TSQL = "tsql"
class _Dialect(type):
@ -53,7 +54,6 @@ class _Dialect(type):
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
@ -95,7 +95,6 @@ class Dialect(metaclass=_Dialect):
tokenizer_class = None
parser_class = None
generator_class = None
tokenizer = None
@classmethod
def get_or_raise(cls, dialect):
@ -138,6 +137,12 @@ class Dialect(metaclass=_Dialect):
def transpile(self, code, **opts):
return self.generate(self.parse(code), **opts)
@property
def tokenizer(self):
    """Return this dialect's tokenizer, creating it on first access.

    The instance is built from ``self.tokenizer_class`` the first time the
    property is read and stored on ``self._tokenizer``; later reads return
    that same stored object.
    """
    if hasattr(self, "_tokenizer"):
        return self._tokenizer
    # First access: instantiate once and remember it on the instance.
    self._tokenizer = self.tokenizer_class()
    return self._tokenizer
def parser(self, **opts):
return self.parser_class(
**{
@ -170,7 +175,15 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
def _rename(self, expression):
args = (
self.expressions(expression, flat=True)
if isinstance(expression, exp.Func) and expression.is_var_len_args
else csv(*[self.sql(e) for e in expression.args.values()])
)
return f"{name}({args})"
return _rename
def approx_count_distinct_sql(self, expression):

View file

@ -108,7 +108,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),

View file

@ -106,6 +106,11 @@ class Snowflake(Dialect):
"TO_TIMESTAMP": _snowflake_to_timestamp,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
"DATE_PART": lambda self: self._parse_extract(),
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@ -118,10 +123,20 @@ class Snowflake(Dialect):
class Tokenizer(Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
"$": TokenType.DOLLAR, # needed to break for quotes
}
KEYWORDS = {
**Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
}
class Generator(Generator):
@ -132,6 +147,11 @@ class Snowflake(Dialect):
exp.UnixToTime: _unix_to_time,
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")

View file

@ -82,6 +82,7 @@ class Spark(Hive):
TRANSFORMS = {
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),

38
sqlglot/dialects/tsql.py Normal file
View file

@ -0,0 +1,38 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer, TokenType
class TSQL(Dialect):
    """Dialect definition for T-SQL.

    Only declares configuration: dialect-level settings, extra tokenizer
    keywords for T-SQL type names, and generator type-name overrides.
    """

    # Dialect-level settings consumed by the Dialect machinery.
    # NOTE(review): exact semantics of these values are defined outside this
    # file — confirm against the Dialect base class.
    null_ordering = "nulls_are_small"
    time_format = "'yyyy-mm-dd hh:mm:ss'"

    class Tokenizer(Tokenizer):
        """Tokenizer overrides: identifier delimiters and T-SQL type keywords."""

        # A double quote, plus a ("[", "]") start/end pair.
        IDENTIFIERS = ['"', ("[", "]")]

        # Base keywords extended with T-SQL type names mapped to token types.
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "BIT": TokenType.BOOLEAN,
            "REAL": TokenType.FLOAT,
            "NTEXT": TokenType.TEXT,
            "SMALLDATETIME": TokenType.DATETIME,
            "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
            "TIME": TokenType.TIMESTAMP,
            "VARBINARY": TokenType.BINARY,
            "IMAGE": TokenType.IMAGE,
            "MONEY": TokenType.MONEY,
            "SMALLMONEY": TokenType.SMALLMONEY,
            "ROWVERSION": TokenType.ROWVERSION,
            "SQL_VARIANT": TokenType.SQL_VARIANT,
            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
            "XML": TokenType.XML,
        }

    class Generator(Generator):
        """Generator overrides: type names emitted when producing T-SQL."""

        # Base type mapping with the T-SQL spellings for these three types.
        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            exp.DataType.Type.BOOLEAN: "BIT",
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.DECIMAL: "NUMERIC",
        }

View file

@ -1,4 +1,5 @@
import inspect
import numbers
import re
import sys
from collections import deque
@ -6,7 +7,7 @@ from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
class _Expression(type):
@ -350,7 +351,8 @@ class Expression(metaclass=_Expression):
Args:
fun (function): a function which takes a node as an argument and returns a
new transformed node or the same node without modifications.
new transformed node or the same node without modifications. If the function
returns None, then the corresponding node will be removed from the syntax tree.
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
modified in place.
@ -360,9 +362,7 @@ class Expression(metaclass=_Expression):
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
if new_node is None:
raise ValueError("A transformed node cannot be None")
if not isinstance(new_node, Expression):
if new_node is None or not isinstance(new_node, Expression):
return new_node
if new_node is not node:
new_node.parent = node.parent
@ -843,10 +843,6 @@ class Ordered(Expression):
arg_types = {"this": True, "desc": True, "nulls_first": True}
class Properties(Expression):
arg_types = {"expressions": True}
class Property(Expression):
arg_types = {"this": True, "value": True}
@ -891,6 +887,42 @@ class AnonymousProperty(Property):
pass
class Properties(Expression):
    """Expression node holding a list of property expressions."""

    arg_types = {"expressions": True}

    # Upper-cased property name -> dedicated property expression class.
    # Names not listed here fall back to AnonymousProperty in from_dict.
    PROPERTY_KEY_MAPPING = {
        "AUTO_INCREMENT": AutoIncrementProperty,
        "CHARACTER_SET": CharacterSetProperty,
        "COLLATE": CollateProperty,
        "COMMENT": SchemaCommentProperty,
        "ENGINE": EngineProperty,
        "FORMAT": FileFormatProperty,
        "LOCATION": LocationProperty,
        "PARTITIONED_BY": PartitionedByProperty,
        "TABLE_FORMAT": TableFormatProperty,
    }

    @classmethod
    def from_dict(cls, properties_dict):
        """Build a Properties node from a plain ``{name: value}`` mapping.

        Args:
            properties_dict (dict): property names (str) mapped to values
                accepted by ``_convert_value``.

        Returns:
            Properties: one property expression per dict entry, in dict
            order. The class is looked up case-insensitively in
            PROPERTY_KEY_MAPPING and defaults to AnonymousProperty; the
            original key is kept as a string literal in ``this``.

        Raises:
            ValueError: if a value has an unsupported type.
        """
        expressions = []
        for key, value in properties_dict.items():
            property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
            expressions.append(
                property_cls(this=Literal.string(key), value=cls._convert_value(value))
            )
        return cls(expressions=expressions)

    @staticmethod
    def _convert_value(value):
        """Convert a Python value into an expression node.

        Expressions pass through unchanged, strings become string literals,
        numbers become number literals, and lists become a Tuple of
        recursively converted items.

        Raises:
            ValueError: for any other type.
        """
        if isinstance(value, Expression):
            return value
        if isinstance(value, str):
            return Literal.string(value)
        if isinstance(value, numbers.Number):
            return Literal.number(value)
        if isinstance(value, list):
            # The recursive call must be qualified: names bound in a class
            # body are not visible from inside its methods, so a bare
            # `_convert_value` raises NameError here.
            return Tuple(expressions=[Properties._convert_value(v) for v in value])
        raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
class Qualify(Expression):
pass
@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression):
)
properties_expression = None
if properties:
properties_str = " ".join(
[f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
into=Properties,
dialect=dialect,
**opts,
)
properties_expression = Properties.from_dict(properties)
return Create(
this=table_expression,
@ -1650,6 +1674,10 @@ class Star(Expression):
return "*"
class Parameter(Expression):
    """Expression node for a parameter reference; the name is stored in ``this``."""
class Placeholder(Expression):
arg_types = {}
@ -1688,6 +1716,7 @@ class DataType(Expression):
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATE = auto()
DATETIME = auto()
ARRAY = auto()
@ -1702,6 +1731,13 @@ class DataType(Expression):
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
UNIQUEIDENTIFIER = auto()
MONEY = auto()
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@ -2976,7 +3012,7 @@ def replace_children(expression, fun):
else:
new_child_nodes.append(cn)
expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
def column_table_names(expression):

View file

@ -748,6 +748,9 @@ class Generator:
def structkwarg_sql(self, expression):
    """Render a struct keyword argument as ``<this> <expression>``.

    Both parts are generated through ``self.sql`` and joined by one space.
    """
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql} {value_sql}"
def parameter_sql(self, expression):
    """Render a parameter as ``@`` followed by the SQL of its ``this`` arg."""
    name_sql = self.sql(expression, "this")
    return "@" + f"{name_sql}"
def placeholder_sql(self, *_):
    """Render a placeholder; always the literal ``?``, ignoring any arguments."""
    placeholder = "?"
    return placeholder
@ -903,7 +906,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")

View file

@ -81,6 +81,7 @@ class Parser:
TokenType.INTERVAL,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
TokenType.DATETIME,
TokenType.DATE,
TokenType.DECIMAL,
@ -92,6 +93,13 @@ class Parser:
TokenType.SERIAL,
TokenType.SMALLSERIAL,
TokenType.BIGSERIAL,
TokenType.XML,
TokenType.UNIQUEIDENTIFIER,
TokenType.MONEY,
TokenType.SMALLMONEY,
TokenType.ROWVERSION,
TokenType.IMAGE,
TokenType.SQL_VARIANT,
*NESTED_TYPE_TOKENS,
}
@ -233,6 +241,7 @@ class Parser:
TIMESTAMPS = {
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
}
SET_OPERATIONS = {
@ -315,6 +324,7 @@ class Parser:
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self.expression(
@ -1497,12 +1507,19 @@ class Parser:
if type_token in self.TIMESTAMPS:
tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
self._match(TokenType.WITHOUT_TIME_ZONE)
if tz:
return exp.DataType(
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
if ltz:
return exp.DataType(
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
self._match(TokenType.WITHOUT_TIME_ZONE)
return exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
@ -1845,8 +1862,11 @@ class Parser:
def _parse_extract(self):
this = self._parse_var() or self._parse_type()
if not self._match(TokenType.FROM):
self.raise_error("Expected FROM after EXTRACT", self._prev)
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
if not self._match(TokenType.COMMA):
self.raise_error("Expected FROM or comma after EXTRACT", self._prev)
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())

View file

@ -41,6 +41,7 @@ class TokenType(AutoName):
LR_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
SPACE = auto()
BREAK = auto()
@ -75,6 +76,7 @@ class TokenType(AutoName):
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
DATETIME = auto()
DATE = auto()
UUID = auto()
@ -86,6 +88,13 @@ class TokenType(AutoName):
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
UNIQUEIDENTIFIER = auto()
MONEY = auto()
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
# keywords
ADD_FILE = auto()
@ -247,6 +256,7 @@ class TokenType(AutoName):
WINDOW = auto()
WITH = auto()
WITH_TIME_ZONE = auto()
WITH_LOCAL_TIME_ZONE = auto()
WITHIN_GROUP = auto()
WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()
@ -340,7 +350,7 @@ class Tokenizer(metaclass=_Tokenizer):
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"#": TokenType.ANNOTATION,
"$": TokenType.DOLLAR,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
@ -520,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer):
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"ARRAY": TokenType.ARRAY,
@ -561,6 +572,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BYTEA": TokenType.BINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
"DATE": TokenType.DATE,
"DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE,