
Merging upstream version 26.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:59:10 +01:00
parent e2fd836612
commit 63d24513e5
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
65 changed files with 45416 additions and 44542 deletions

View file

@@ -1,6 +1,28 @@
Changelog
=========
## [v26.0.0] - 2024-12-10
### :boom: BREAKING CHANGES
- due to [`1d3c9aa`](https://github.com/tobymao/sqlglot/commit/1d3c9aa604c7bf60166a0e5587f1a8d88b89bea6) - Transpile support for bitor/bit_or snowflake function *(PR [#4486](https://github.com/tobymao/sqlglot/pull/4486) by [@ankur334](https://github.com/ankur334))*:
Transpile support for bitor/bit_or snowflake function (#4486)
- due to [`ab10851`](https://github.com/tobymao/sqlglot/commit/ab108518c53173ddf71ac1dfd9e45df6ac621b81) - Preserve roundtrips of DATETIME/DATETIME2 *(PR [#4491](https://github.com/tobymao/sqlglot/pull/4491) by [@VaggelisD](https://github.com/VaggelisD))*:
Preserve roundtrips of DATETIME/DATETIME2 (#4491)
### :sparkles: New Features
- [`1d3c9aa`](https://github.com/tobymao/sqlglot/commit/1d3c9aa604c7bf60166a0e5587f1a8d88b89bea6) - **snowflake**: Transpile support for bitor/bit_or snowflake function *(PR [#4486](https://github.com/tobymao/sqlglot/pull/4486) by [@ankur334](https://github.com/ankur334))*
- [`822aea0`](https://github.com/tobymao/sqlglot/commit/822aea0826f09fa773193004acb2af99e495fddd) - **snowflake**: Support for inline FOREIGN KEY *(PR [#4493](https://github.com/tobymao/sqlglot/pull/4493) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *addresses issue [#4489](https://github.com/tobymao/sqlglot/issues/4489) opened by [@kylekarpack](https://github.com/kylekarpack)*
### :bug: Bug Fixes
- [`ab10851`](https://github.com/tobymao/sqlglot/commit/ab108518c53173ddf71ac1dfd9e45df6ac621b81) - **tsql**: Preserve roundtrips of DATETIME/DATETIME2 *(PR [#4491](https://github.com/tobymao/sqlglot/pull/4491) by [@VaggelisD](https://github.com/VaggelisD))*
- [`43975e4`](https://github.com/tobymao/sqlglot/commit/43975e4b7abcd640cd5a0f91aea1fbda8dd893cb) - **duckdb**: Allow escape strings similar to Postgres *(PR [#4497](https://github.com/tobymao/sqlglot/pull/4497) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *fixes issue [#4496](https://github.com/tobymao/sqlglot/issues/4496) opened by [@LennartH](https://github.com/LennartH)*
## [v25.34.1] - 2024-12-10
### :boom: BREAKING CHANGES
- due to [`f70f124`](https://github.com/tobymao/sqlglot/commit/f70f12408fbaf021dd105f2eac957b9e6fac045d) - transpile MySQL FORMAT to DuckDB *(PR [#4488](https://github.com/tobymao/sqlglot/pull/4488) by [@georgesittas](https://github.com/georgesittas))*:
@@ -5448,3 +5470,4 @@ Changelog
[v25.33.0]: https://github.com/tobymao/sqlglot/compare/v25.32.1...v25.33.0
[v25.34.0]: https://github.com/tobymao/sqlglot/compare/v25.33.0...v25.34.0
[v25.34.1]: https://github.com/tobymao/sqlglot/compare/v25.34.0...v25.34.1
[v26.0.0]: https://github.com/tobymao/sqlglot/compare/v25.34.1...v26.0.0

File diff suppressed because one or more lines are too long

View file

@@ -76,8 +76,8 @@
</span><span id="L-12"><a href="#L-12"><span class="linenos">12</span></a><span class="n">__version_tuple__</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span>
</span><span id="L-13"><a href="#L-13"><span class="linenos">13</span></a><span class="n">version_tuple</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span>
</span><span id="L-14"><a href="#L-14"><span class="linenos">14</span></a>
</span><span id="L-15"><a href="#L-15"><span class="linenos">15</span></a><span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;25.34.1&#39;</span>
</span><span id="L-16"><a href="#L-16"><span class="linenos">16</span></a><span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">25</span><span class="p">,</span> <span class="mi">34</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</span><span id="L-15"><a href="#L-15"><span class="linenos">15</span></a><span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;26.0.0&#39;</span>
</span><span id="L-16"><a href="#L-16"><span class="linenos">16</span></a><span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">26</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
</span></pre></div>
@@ -97,7 +97,7 @@
<section id="version">
<div class="attr variable">
<span class="name">version</span><span class="annotation">: str</span> =
<span class="default_value">&#39;25.34.1&#39;</span>
<span class="default_value">&#39;26.0.0&#39;</span>
</div>
@@ -109,7 +109,7 @@
<section id="version_tuple">
<div class="attr variable">
<span class="name">version_tuple</span><span class="annotation">: object</span> =
<span class="default_value">(25, 34, 1)</span>
<span class="default_value">(26, 0, 0)</span>
</div>

File diff suppressed because one or more lines are too long (this notice repeats for 28 consecutive files)

File diff suppressed because it is too large

View file

@@ -1920,7 +1920,7 @@ belong to some totally-ordered set.</p>
<section id="DATE_UNITS">
<div class="attr variable">
<span class="name">DATE_UNITS</span> =
<span class="default_value">{&#39;quarter&#39;, &#39;year&#39;, &#39;year_month&#39;, &#39;month&#39;, &#39;week&#39;, &#39;day&#39;}</span>
<span class="default_value">{&#39;year_month&#39;, &#39;year&#39;, &#39;quarter&#39;, &#39;week&#39;, &#39;day&#39;, &#39;month&#39;}</span>
</div>

View file

@@ -640,7 +640,7 @@
<div class="attr variable">
<span class="name">ALL_JSON_PATH_PARTS</span> =
<input id="ALL_JSON_PATH_PARTS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="ALL_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;}</span>
<label class="view-value-button pdoc-button" for="ALL_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;}</span>
</div>

File diff suppressed because one or more lines are too long

View file

@@ -581,7 +581,7 @@ queries if it would result in multiple table selects in a single query:</p>
<div class="attr variable">
<span class="name">UNMERGABLE_ARGS</span> =
<input id="UNMERGABLE_ARGS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="UNMERGABLE_ARGS-view-value"></label><span class="default_value">{&#39;locks&#39;, &#39;sample&#39;, &#39;options&#39;, &#39;match&#39;, &#39;laterals&#39;, &#39;into&#39;, &#39;limit&#39;, &#39;windows&#39;, &#39;kind&#39;, &#39;operation_modifiers&#39;, &#39;pivots&#39;, &#39;distinct&#39;, &#39;prewhere&#39;, &#39;having&#39;, &#39;connect&#39;, &#39;offset&#39;, &#39;qualify&#39;, &#39;cluster&#39;, &#39;group&#39;, &#39;settings&#39;, &#39;with&#39;, &#39;format&#39;, &#39;sort&#39;, &#39;distribute&#39;}</span>
<label class="view-value-button pdoc-button" for="UNMERGABLE_ARGS-view-value"></label><span class="default_value">{&#39;distinct&#39;, &#39;locks&#39;, &#39;operation_modifiers&#39;, &#39;offset&#39;, &#39;format&#39;, &#39;prewhere&#39;, &#39;pivots&#39;, &#39;group&#39;, &#39;kind&#39;, &#39;limit&#39;, &#39;sample&#39;, &#39;connect&#39;, &#39;laterals&#39;, &#39;sort&#39;, &#39;distribute&#39;, &#39;qualify&#39;, &#39;having&#39;, &#39;into&#39;, &#39;cluster&#39;, &#39;settings&#39;, &#39;options&#39;, &#39;windows&#39;, &#39;with&#39;, &#39;match&#39;}</span>
</div>

View file

@@ -3231,7 +3231,7 @@ prefix are statically known.</p>
<div class="attr variable">
<span class="name">DATETRUNC_COMPARISONS</span> =
<input id="DATETRUNC_COMPARISONS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="DATETRUNC_COMPARISONS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="../expressions.html#EQ">sqlglot.expressions.EQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#NEQ">sqlglot.expressions.NEQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LTE">sqlglot.expressions.LTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GTE">sqlglot.expressions.GTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#In">sqlglot.expressions.In</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GT">sqlglot.expressions.GT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LT">sqlglot.expressions.LT</a>&#39;&gt;}</span>
<label class="view-value-button pdoc-button" for="DATETRUNC_COMPARISONS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="../expressions.html#GT">sqlglot.expressions.GT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LTE">sqlglot.expressions.LTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#EQ">sqlglot.expressions.EQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#NEQ">sqlglot.expressions.NEQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#In">sqlglot.expressions.In</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LT">sqlglot.expressions.LT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GTE">sqlglot.expressions.GTE</a>&#39;&gt;}</span>
</div>
@@ -3315,7 +3315,7 @@ prefix are statically known.</p>
<section id="JOINS">
<div class="attr variable">
<span class="name">JOINS</span> =
<span class="default_value">{(&#39;&#39;, &#39;&#39;), (&#39;RIGHT&#39;, &#39;OUTER&#39;), (&#39;RIGHT&#39;, &#39;&#39;), (&#39;&#39;, &#39;INNER&#39;)}</span>
<span class="default_value">{(&#39;RIGHT&#39;, &#39;OUTER&#39;), (&#39;&#39;, &#39;INNER&#39;), (&#39;RIGHT&#39;, &#39;&#39;), (&#39;&#39;, &#39;&#39;)}</span>
</div>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View file

@@ -565,6 +565,8 @@ class ClickHouse(Dialect):
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
"""
index = self._index
this = self._parse_id_var()
self._match(TokenType.COLON)
kind = self._parse_types(check_func=False, allow_identifiers=False) or (
@@ -572,12 +574,32 @@
)
if not kind:
self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
self._retreat(index)
return None
elif not self._match(TokenType.R_BRACE):
self.raise_error("Expecting }")
return self.expression(exp.Placeholder, this=this, kind=kind)
def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
l_brace = self._match(TokenType.L_BRACE, advance=False)
bracket = super()._parse_bracket(this)
if l_brace and isinstance(bracket, exp.Struct):
varmap = exp.VarMap(keys=exp.Array(), values=exp.Array())
for expression in bracket.expressions:
if not isinstance(expression, exp.PropertyEQ):
break
varmap.args["keys"].append("expressions", exp.Literal.string(expression.name))
varmap.args["values"].append("expressions", expression.expression)
return varmap
return bracket
def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
this = super()._parse_in(this)
this.set("is_global", is_global)
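A minimal sketch of the new brace-literal handling in _parse_bracket (the expected output mirrors the ClickHouse test added further down):

import sqlglot

# {'key1': 1, ...} is now parsed into exp.VarMap instead of failing as a
# placeholder, and is re-emitted through ClickHouse's map() function.
sql = "INSERT INTO tab VALUES ({'key1': 1, 'key2': 10})"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
# INSERT INTO tab VALUES (map('key1', 1, 'key2', 10))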

View file

@@ -15,7 +15,13 @@ from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_DIFF = t.Union[
exp.DateAdd,
exp.DateDiff,
exp.DateSub,
exp.TsOrDsAdd,
exp.TsOrDsDiff,
]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]

View file

@@ -205,6 +205,7 @@ class MySQL(Dialect):
"MEDIUMINT": TokenType.MEDIUMINT,
"MEMBER OF": TokenType.MEMBER_OF,
"SEPARATOR": TokenType.SEPARATOR,
"SERIAL": TokenType.SERIAL,
"START": TokenType.BEGIN,
"SIGNED": TokenType.BIGINT,
"SIGNED INTEGER": TokenType.BIGINT,

View file

@@ -687,6 +687,14 @@ class Postgres(Dialect):
values = self.expressions(expression, key="values", flat=True)
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
if (
expression.is_type(exp.DataType.Type.DOUBLE, exp.DataType.Type.FLOAT)
and expression.expressions
):
# Postgres doesn't support precision for REAL and DOUBLE PRECISION types
return f"FLOAT({self.expressions(expression, flat=True)})"
return super().datatype_sql(expression)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
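A sketch of the fallback, assuming default-dialect input: Postgres rejects a precision on REAL/DOUBLE PRECISION but accepts FLOAT(n), so the precision is preserved through the FLOAT spelling.

import sqlglot

# Without this branch the type would go through Postgres's TYPE_MAPPING
# (FLOAT -> REAL) and the attached precision would yield invalid SQL.
print(sqlglot.transpile("SELECT CAST(x AS FLOAT(8))", write="postgres")[0])
# SELECT CAST(x AS FLOAT(8))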

View file

@@ -32,7 +32,7 @@ from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot._typing import E, B
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
@@ -107,11 +107,14 @@ def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
return _builder
def _build_bitor(args: t.List) -> exp.BitwiseOr | exp.Anonymous:
def _build_bitwise(expr_type: t.Type[B], name: str) -> t.Callable[[t.List], B | exp.Anonymous]:
def _builder(args: t.List) -> B | exp.Anonymous:
if len(args) == 3:
return exp.Anonymous(this="BITOR", expressions=args)
return exp.Anonymous(this=name, expressions=args)
return binary_from_function(exp.BitwiseOr)(args)
return binary_from_function(expr_type)(args)
return _builder
# https://docs.snowflake.com/en/sql-reference/functions/div0
@@ -398,11 +401,15 @@ class Snowflake(Dialect):
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
step=seq_get(args, 2),
),
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BITOR": _build_bitor,
"BIT_OR": _build_bitor,
"BOOLXOR": binary_from_function(exp.Xor),
"BITXOR": _build_bitwise(exp.BitwiseXor, "BITXOR"),
"BIT_XOR": _build_bitwise(exp.BitwiseXor, "BITXOR"),
"BITOR": _build_bitwise(exp.BitwiseOr, "BITOR"),
"BIT_OR": _build_bitwise(exp.BitwiseOr, "BITOR"),
"BITSHIFTLEFT": _build_bitwise(exp.BitwiseLeftShift, "BITSHIFTLEFT"),
"BIT_SHIFTLEFT": _build_bitwise(exp.BitwiseLeftShift, "BIT_SHIFTLEFT"),
"BITSHIFTRIGHT": _build_bitwise(exp.BitwiseRightShift, "BITSHIFTRIGHT"),
"BIT_SHIFTRIGHT": _build_bitwise(exp.BitwiseRightShift, "BIT_SHIFTRIGHT"),
"BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"),
"DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": _build_date_time_add(exp.DateAdd),
@@ -885,8 +892,10 @@ class Snowflake(Dialect):
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
exp.BitwiseOr: rename_func("BITOR"),
exp.BitwiseXor: rename_func("BITXOR"),
exp.BitwiseLeftShift: rename_func("BITSHIFTLEFT"),
exp.BitwiseRightShift: rename_func("BITSHIFTRIGHT"),
exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
@@ -1088,7 +1097,11 @@ class Snowflake(Dialect):
else:
unnest_alias = exp.TableAlias(this="_u", columns=columns)
explode = f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
table_input = self.sql(expression.expressions[0])
if not table_input.startswith("INPUT =>"):
table_input = f"INPUT => {table_input}"
explode = f"TABLE(FLATTEN({table_input}))"
alias = self.sql(unnest_alias)
alias = f" AS {alias}" if alias else ""
return f"{explode}{alias}"
@@ -1202,3 +1215,12 @@ class Snowflake(Dialect):
this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
return self.func("TO_CHAR", this, self.format_time(expression))
def datesub_sql(self, expression: exp.DateSub) -> str:
value = expression.expression
if value:
value.replace(value * (-1))
else:
self.unsupported("DateSub cannot be transpiled if the subtracted count is unknown")
return date_delta_sql("DATEADD")(self, expression)
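Since Snowflake has no DATESUB-style function, the new datesub_sql negates the count and reuses DATEADD; the example mirrors the BigQuery test added below:

import sqlglot

sql = "SELECT DATE_SUB(DATE '2008-12-25', INTERVAL 5 DAY)"
print(sqlglot.transpile(sql, read="bigquery", write="snowflake")[0])
# SELECT DATEADD(DAY, '5' * -1, CAST('2008-12-25' AS DATE))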

View file

@@ -14,11 +14,18 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
class StarRocks(MySQL):
STRICT_JSON_PATH_SYNTAX = False
class Tokenizer(MySQL.Tokenizer):
KEYWORDS = {
**MySQL.Tokenizer.KEYWORDS,
"LARGEINT": TokenType.INT128,
}
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
@@ -34,7 +41,9 @@ class StarRocks(MySQL):
PROPERTY_PARSERS = {
**MySQL.Parser.PROPERTY_PARSERS,
"UNIQUE": lambda self: self._parse_composite_key_property(exp.UniqueKeyProperty),
"PROPERTIES": lambda self: self._parse_wrapped_properties(),
"PARTITION BY": lambda self: self._parse_partition_by_opt_range(),
}
def _parse_create(self) -> exp.Create | exp.Command:
@@ -70,6 +79,32 @@ class StarRocks(MySQL):
return unnest
def _parse_partitioning_granularity_dynamic(self) -> exp.PartitionByRangePropertyDynamic:
self._match_text_seq("START")
start = self._parse_wrapped(self._parse_string)
self._match_text_seq("END")
end = self._parse_wrapped(self._parse_string)
self._match_text_seq("EVERY")
every = self._parse_wrapped(lambda: self._parse_interval() or self._parse_number())
return self.expression(
exp.PartitionByRangePropertyDynamic, start=start, end=end, every=every
)
def _parse_partition_by_opt_range(
self,
) -> exp.PartitionedByProperty | exp.PartitionByRangeProperty:
if self._match_text_seq("RANGE"):
partition_expressions = self._parse_wrapped_id_vars()
create_expressions = self._parse_wrapped_csv(
self._parse_partitioning_granularity_dynamic
)
return self.expression(
exp.PartitionByRangeProperty,
partition_expressions=partition_expressions,
create_expressions=create_expressions,
)
return super()._parse_partitioned_by()
class Generator(MySQL.Generator):
EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
@@ -81,6 +116,7 @@ class StarRocks(MySQL):
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.INT128: "LARGEINT",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
@@ -89,6 +125,8 @@ class StarRocks(MySQL):
PROPERTIES_LOCATION = {
**MySQL.Generator.PROPERTIES_LOCATION,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.UniqueKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionByRangeProperty: exp.Properties.Location.POST_SCHEMA,
}
TRANSFORMS = {
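A sketch of the LARGEINT mapping (mirrors the StarRocks identity test added below):

import sqlglot

# LARGEINT now tokenizes to exp.DataType.Type.INT128 and is generated back
# as LARGEINT, so the DDL round-trips unchanged.
ddl = "CREATE TABLE foo (col1 LARGEINT) DISTRIBUTED BY HASH (col1) BUCKETS 1"
assert sqlglot.transpile(ddl, read="starrocks", write="starrocks")[0] == ddl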

View file

@@ -232,7 +232,7 @@ def _build_date_delta(
if start_date and start_date.is_number:
# Numeric types are valid DATETIME values
if start_date.is_int:
adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this))
adds = DEFAULT_START_DATE + datetime.timedelta(days=start_date.to_py())
start_date = exp.Literal.string(adds.strftime("%F"))
else:
# We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
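to_py() returns the literal's underlying Python value directly, replacing the manual int() cast; integer start dates still resolve against DEFAULT_START_DATE (1900-01-01), as in the T-SQL test added below:

import sqlglot

sql = "SELECT DATEADD(DAY, DATEDIFF(DAY, -3, GETDATE()), '08:00:00')"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
# SELECT DATEADD(DAY, DATEDIFF(DAY, CAST('1899-12-29' AS DATETIME2), CAST(GETDATE() AS DATETIME2)), '08:00:00')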

View file

@@ -2855,6 +2855,21 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}
# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/
class PartitionByRangeProperty(Property):
arg_types = {"partition_expressions": True, "create_expressions": True}
# https://docs.starrocks.io/docs/table_design/data_distribution/#range-partitioning
class PartitionByRangePropertyDynamic(Expression):
arg_types = {"this": False, "start": True, "end": True, "every": True}
# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/
class UniqueKeyProperty(Property):
arg_types = {"expressions": True}
# https://www.postgresql.org/docs/current/sql-createtable.html
class PartitionBoundSpec(Expression):
# this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...)
@@ -6665,6 +6680,11 @@ class Week(Func):
arg_types = {"this": True, "mode": False}
class XMLElement(Func):
_sql_names = ["XMLELEMENT"]
arg_types = {"this": True, "expressions": False}
class XMLTable(Func):
arg_types = {"this": True, "passing": False, "columns": False, "by_ref": False}

View file

@@ -3759,6 +3759,10 @@ class Generator(metaclass=_Generator):
def duplicatekeyproperty_sql(self, expression: exp.DuplicateKeyProperty) -> str:
return f"DUPLICATE KEY ({self.expressions(expression, flat=True)})"
# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/
def uniquekeyproperty_sql(self, expression: exp.UniqueKeyProperty) -> str:
return f"UNIQUE KEY ({self.expressions(expression, flat=True)})"
# https://docs.starrocks.io/docs/sql-reference/sql-statements/data-definition/CREATE_TABLE/#distribution_desc
def distributedbyproperty_sql(self, expression: exp.DistributedByProperty) -> str:
expressions = self.expressions(expression, flat=True)
@@ -4612,3 +4616,24 @@ class Generator(metaclass=_Generator):
include = f"{include} AS {alias}"
return include
def xmlelement_sql(self, expression: exp.XMLElement) -> str:
name = f"NAME {self.sql(expression, 'this')}"
return self.func("XMLELEMENT", name, *expression.expressions)
def partitionbyrangeproperty_sql(self, expression: exp.PartitionByRangeProperty) -> str:
partitions = self.expressions(expression, "partition_expressions")
create = self.expressions(expression, "create_expressions")
return f"PARTITION BY RANGE {self.wrap(partitions)} {self.wrap(create)}"
def partitionbyrangepropertydynamic_sql(
self, expression: exp.PartitionByRangePropertyDynamic
) -> str:
start = self.sql(expression, "start")
end = self.sql(expression, "end")
every = expression.args["every"]
if isinstance(every, exp.Interval) and every.this.is_string:
every.this.replace(exp.Literal.number(every.name))
return f"START {self.wrap(start)} END {self.wrap(end)} EVERY {self.wrap(self.sql(every))}"

View file

@@ -254,6 +254,27 @@ def to_node(
if dt.comments and dt.comments[0].startswith("source: ")
}
pivots = scope.pivots
pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
if pivot:
# For each aggregation function, the pivot creates a new column for each field in category
# combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
# b_value_sum, b. Because of this stepwise manner, the aggfunc 'sum(value) as value_sum'
# belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
# to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
# in the lineage, so look up the pivot column name by index and map it to the columns used
# in the aggregation.
#
# Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
pivot_columns = pivot.args["columns"]
pivot_aggs_count = len(pivot.expressions)
pivot_column_mapping = {}
for i, agg in enumerate(pivot.expressions):
agg_cols = list(agg.find_all(exp.Column))
for col_index in range(i, len(pivot_columns), pivot_aggs_count):
pivot_column_mapping[pivot_columns[col_index].name] = agg_cols
for c in source_columns:
table = c.table
source = scope.sources.get(table)
@@ -265,6 +286,7 @@ def to_node(
elif source.scope_type == ScopeType.CTE:
selected_node, _ = scope.selected_sources.get(table, (None, None))
reference_node_name = selected_node.name if selected_node else None
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name,
@@ -276,10 +298,45 @@ def to_node(
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
elif pivot and pivot.alias_or_name == c.table:
downstream_columns = []
column_name = c.name
if any(column_name == pivot_column.name for pivot_column in pivot_columns):
downstream_columns.extend(pivot_column_mapping[column_name])
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
# The column is not in the pivot, so it must be an implicit column of the
# pivoted source -- adapt column to be from the implicit pivoted source.
downstream_columns.append(exp.column(c.this, table=pivot.parent.this))
for downstream_column in downstream_columns:
table = downstream_column.table
source = scope.sources.get(table)
if isinstance(source, Scope):
to_node(
downstream_column.name,
scope=source,
scope_name=table,
dialect=dialect,
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
else:
source = source or exp.Placeholder()
node.downstream.append(
Node(
name=downstream_column.sql(comments=False),
source=source,
expression=source,
)
)
else:
# The source is not a scope and the column is not in any pivot - we've reached the end
# of the line. At this point, if a source is not found it means this column's lineage
# is unknown. This can happen if the definition of a source used in a query is not
# passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(
Node(name=c.sql(comments=False), source=source, expression=source)
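A sketch of the resulting lineage through a pivot (mirrors test_pivot_without_alias added below):

from sqlglot.lineage import lineage

sql = """
SELECT a AS other_a
FROM (SELECT value, category FROM sample_data)
PIVOT (SUM(value) FOR category IN ('a', 'b'))
"""
node = lineage("other_a", sql)
print(node.downstream[0].name)                # _q_0.value
print(node.downstream[0].downstream[0].name)  # sample_data.value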

View file

@@ -66,6 +66,7 @@ def qualify_columns(
_expand_alias_refs(
scope,
resolver,
dialect,
expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
)
@@ -73,9 +74,9 @@
_qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
_expand_alias_refs(scope, resolver, dialect)
if not isinstance(scope.expression, exp.UDTF):
if isinstance(scope.expression, exp.Select):
if expand_stars:
_expand_stars(
scope,
@@ -236,7 +237,15 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
return column_tables
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
def _expand_alias_refs(
scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False
) -> None:
"""
Expand references to aliases.
Example:
SELECT y.foo AS bar, bar * 2 AS baz FROM y
=> SELECT y.foo AS bar, y.foo * 2 AS baz FROM y
"""
expression = scope.expression
if not isinstance(expression, exp.Select):
@@ -309,6 +318,12 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
# Snowflake allows alias expansion in the JOIN ... ON clause (and almost everywhere else)
# https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes
if dialect == "snowflake":
for join in expression.args.get("joins") or []:
replace_columns(join)
scope.clear_cache()
@@ -883,10 +898,22 @@ class Resolver:
for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
pseudocolumns = self._get_source_pseudocolumns(name)
if pseudocolumns:
columns = list(columns)
columns.extend(c for c in pseudocolumns if c not in columns)
self._get_source_columns_cache[cache_key] = columns
return self._get_source_columns_cache[cache_key]
def _get_source_pseudocolumns(self, name: str) -> t.Sequence[str]:
if self.schema.dialect == "snowflake" and self.scope.expression.args.get("connect"):
# When there is a CONNECT BY clause, there is only one table being scanned
# See: https://docs.snowflake.com/en/sql-reference/constructs/connect-by
return ["LEVEL"]
return []
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:
self._source_columns = {
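A sketch of the Snowflake-specific behavior (alias expansion into the JOIN ... ON clause; the LEVEL pseudocolumn under CONNECT BY works the same way). qualify() also normalizes and quotes identifiers, so the exact output spelling is dialect-dependent:

from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

sql = "SELECT x.a AS foo FROM x JOIN y ON foo = y.b"
qualified = qualify(parse_one(sql, read="snowflake"), dialect="snowflake")
# The alias foo in the ON clause is expanded back to x.a.
print(qualified.sql("snowflake"))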

View file

@@ -1328,13 +1328,17 @@ def _flat_simplify(expression, simplifier, root=True):
return expression
def gen(expression: t.Any) -> str:
def gen(expression: t.Any, comments: bool = False) -> str:
"""Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive so we have a bare minimum sql generator here.
Args:
expression: the expression to convert into a SQL string.
comments: whether to include the expression's comments.
"""
return Gen().gen(expression)
return Gen().gen(expression, comments=comments)
class Gen:
@@ -1342,7 +1346,7 @@ class Gen:
self.stack = []
self.sqls = []
def gen(self, expression: exp.Expression) -> str:
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
self.stack = [expression]
self.sqls.clear()
@@ -1350,6 +1354,9 @@ class Gen:
node = self.stack.pop()
if isinstance(node, exp.Expression):
if comments and node.comments:
self.stack.append(f" /*{','.join(node.comments)}*/")
exp_handler_name = f"{node.key}_sql"
if hasattr(self, exp_handler_name):
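Usage sketch of the new flag (mirrors the optimizer test added below):

from sqlglot import parse_one
from sqlglot.optimizer import simplify

print(simplify.gen(parse_one("select item_id /* description */"), comments=True))
# SELECT :expressions,item_id /* description */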

View file

@@ -920,7 +920,7 @@ class Parser(metaclass=_Parser):
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DISTRIBUTED": lambda self: self._parse_distributed_property(),
"DUPLICATE": lambda self: self._parse_duplicate(),
"DUPLICATE": lambda self: self._parse_composite_key_property(exp.DuplicateKeyProperty),
"DYNAMIC": lambda self: self.expression(exp.DynamicProperty),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
@@ -1143,6 +1143,11 @@ class Parser(metaclass=_Parser):
"TRIM": lambda self: self._parse_trim(),
"TRY_CAST": lambda self: self._parse_cast(False, safe=True),
"TRY_CONVERT": lambda self: self._parse_convert(False, safe=True),
"XMLELEMENT": lambda self: self.expression(
exp.XMLElement,
this=self._match_text_seq("NAME") and self._parse_id_var(),
expressions=self._match(TokenType.COMMA) and self._parse_csv(self._parse_expression),
),
}
QUERY_MODIFIER_PARSERS = {
@@ -2203,10 +2208,10 @@ class Parser(metaclass=_Parser):
order=self._parse_order(),
)
def _parse_duplicate(self) -> exp.DuplicateKeyProperty:
def _parse_composite_key_property(self, expr_type: t.Type[E]) -> E:
self._match_text_seq("KEY")
expressions = self._parse_wrapped_csv(self._parse_id_var, optional=False)
return self.expression(exp.DuplicateKeyProperty, expressions=expressions)
expressions = self._parse_wrapped_id_vars()
return self.expression(expr_type, expressions=expressions)
def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match_text_seq("(", "SYSTEM_VERSIONING"):
@@ -4615,14 +4620,14 @@ class Parser(metaclass=_Parser):
this = exp.Literal.string(this.to_py())
elif this and this.is_string:
parts = exp.INTERVAL_STRING_RE.findall(this.name)
if len(parts) == 1:
if unit:
if parts and unit:
# Unconsume the eagerly-parsed unit, since the real unit was part of the string
unit = None
self._retreat(self._index - 1)
if len(parts) == 1:
this = exp.Literal.string(parts[0][0])
unit = self.expression(exp.Var, this=parts[0][1].upper())
if self.INTERVAL_SPANS and self._match_text_seq("TO"):
unit = self.expression(
exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True)
@@ -5351,18 +5356,21 @@ class Parser(metaclass=_Parser):
functions = self.FUNCTIONS
function = functions.get(upper)
known_function = function and not anonymous
alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS
alias = not known_function or upper in self.FUNCTIONS_WITH_ALIASED_ARGS
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if alias:
if alias and known_function:
args = self._kv_to_prop_eq(args)
if function and not anonymous:
if "dialect" in function.__code__.co_varnames:
func = function(args, dialect=self.dialect)
if known_function:
func_builder = t.cast(t.Callable, function)
if "dialect" in func_builder.__code__.co_varnames:
func = func_builder(args, dialect=self.dialect)
else:
func = function(args)
func = func_builder(args)
func = self.validate_expression(func, args)
if self.dialect.PRESERVE_ORIGINAL_NAMES:
@@ -6730,7 +6738,9 @@ class Parser(metaclass=_Parser):
def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_set_operations(
self._parse_expression() if alias else self._parse_assignment()
self._parse_alias(self._parse_assignment(), explicit=True)
if alias
else self._parse_assignment()
)
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:

View file

@@ -201,11 +201,13 @@ class Step:
aggregate.add_dependency(step)
step = aggregate
else:
aggregate = None
order = expression.args.get("order")
if order:
if isinstance(step, Aggregate):
if aggregate and isinstance(step, Aggregate):
for i, ordered in enumerate(order.expressions):
if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))

View file

@@ -315,6 +315,14 @@ LANGUAGE js AS
"SELECT CAST(1 AS INT64)",
)
self.validate_all(
"SELECT DATE_SUB(DATE '2008-12-25', INTERVAL 5 DAY)",
write={
"bigquery": "SELECT DATE_SUB(CAST('2008-12-25' AS DATE), INTERVAL '5' DAY)",
"duckdb": "SELECT CAST('2008-12-25' AS DATE) - INTERVAL '5' DAY",
"snowflake": "SELECT DATEADD(DAY, '5' * -1, CAST('2008-12-25' AS DATE))",
},
)
self.validate_all(
"EDIT_DISTANCE(col1, col2, max_distance => 3)",
write={

View file

@@ -154,6 +154,10 @@ class TestClickhouse(Validator):
self.validate_identity(
"CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)"
)
self.validate_identity(
"INSERT INTO tab VALUES ({'key1': 1, 'key2': 10}), ({'key1': 2, 'key2': 20}), ({'key1': 3, 'key2': 30})",
"INSERT INTO tab VALUES (map('key1', 1, 'key2', 10)), (map('key1', 2, 'key2', 20)), (map('key1', 3, 'key2', 30))",
)
self.validate_identity(
"SELECT (toUInt8('1') + toUInt8('2')) IS NOT NULL",
"SELECT NOT ((toUInt8('1') + toUInt8('2')) IS NULL)",

View file

@@ -18,6 +18,7 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)")
self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
self.validate_identity("CREATE TABLE temp (id SERIAL PRIMARY KEY)")
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10")
self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))")

View file

@@ -71,6 +71,9 @@ class TestPostgres(Validator):
self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True)
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
"SELECT * FROM t WHERE some_column >= CURRENT_DATE + INTERVAL '1 day 1 hour' AND some_another_column IS TRUE"
)
self.validate_identity(
"""UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)"""
)
@@ -1289,3 +1292,17 @@
"clickhouse": UnsupportedError,
},
)
def test_xmlelement(self):
self.validate_identity("SELECT XMLELEMENT(NAME foo)")
self.validate_identity("SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES('xyz' AS bar))")
self.validate_identity("SELECT XMLELEMENT(NAME test, XMLATTRIBUTES(a, b)) FROM test")
self.validate_identity(
"SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES(CURRENT_DATE AS bar), 'cont', 'ent')"
)
self.validate_identity(
"""SELECT XMLELEMENT(NAME "foo$bar", XMLATTRIBUTES('xyz' AS "a&b"))"""
)
self.validate_identity(
"SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES('xyz' AS bar), XMLELEMENT(NAME abc), XMLCOMMENT('test'), XMLELEMENT(NAME xyz))"
)

View file

@@ -320,6 +320,7 @@ class TestRedshift(Validator):
)
def test_identity(self):
self.validate_identity("SELECT CAST(value AS FLOAT(8))")
self.validate_identity("1 div", "1 AS div")
self.validate_identity("LISTAGG(DISTINCT foo, ', ')")
self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1")

View file

@@ -21,27 +21,6 @@ class TestSnowflake(Validator):
expr.selects[0].assert_is(exp.AggFunc)
self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
self.assertEqual(
exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql(
"snowflake", pretty=True
),
"""SELECT
IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y"
FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
GREATEST(ARRAY_SIZE(x)) - 1
) + 1))) AS _u(seq, key, path, index, pos, this)
CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this)
WHERE
_u.pos = _u_2.pos_2
OR (
_u.pos > (
ARRAY_SIZE(x) - 1
) AND _u_2.pos_2 = (
ARRAY_SIZE(x) - 1
)
)""",
)
self.validate_identity("exclude := [foo]")
self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))")
self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias")
@@ -976,12 +955,15 @@ WHERE
"snowflake": "EDITDISTANCE(col1, col2, 3)",
},
)
self.validate_identity("SELECT BITOR(a, b) FROM table")
self.validate_identity("SELECT BIT_OR(a, b) FROM table", "SELECT BITOR(a, b) FROM table")
# Test BITOR with three arguments, padding on the left
self.validate_identity("SELECT BITOR(a, b, 'LEFT') FROM table_name")
self.validate_identity("SELECT BITOR(a, b)")
self.validate_identity("SELECT BIT_OR(a, b)", "SELECT BITOR(a, b)")
self.validate_identity("SELECT BITOR(a, b, 'LEFT')")
self.validate_identity("SELECT BITXOR(a, b, 'LEFT')")
self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)")
self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')")
self.validate_identity("SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)")
self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
def test_null_treatment(self):
self.validate_all(
@@ -1600,6 +1582,27 @@
)
def test_flatten(self):
self.assertEqual(
exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql(
"snowflake", pretty=True
),
"""SELECT
IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y"
FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
GREATEST(ARRAY_SIZE(x)) - 1
) + 1))) AS _u(seq, key, path, index, pos, this)
CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this)
WHERE
_u.pos = _u_2.pos_2
OR (
_u.pos > (
ARRAY_SIZE(x) - 1
) AND _u_2.pos_2 = (
ARRAY_SIZE(x) - 1
)
)""",
)
self.validate_all(
"""
select
@@ -1624,6 +1627,75 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS
},
pretty=True,
)
self.validate_all(
"""
SELECT
uc.user_id,
uc.start_ts AS ts,
CASE
WHEN uc.start_ts::DATE >= '2023-01-01' AND uc.country_code IN ('US') AND uc.user_id NOT IN (
SELECT DISTINCT
_id
FROM
users,
LATERAL FLATTEN(INPUT => PARSE_JSON(flags)) datasource
WHERE datasource.value:name = 'something'
)
THEN 'Sample1'
ELSE 'Sample2'
END AS entity
FROM user_countries AS uc
LEFT JOIN (
SELECT user_id, MAX(IFF(service_entity IS NULL,1,0)) AS le_null
FROM accepted_user_agreements
GROUP BY 1
) AS aua
ON uc.user_id = aua.user_id
""",
write={
"snowflake": """SELECT
uc.user_id,
uc.start_ts AS ts,
CASE
WHEN CAST(uc.start_ts AS DATE) >= '2023-01-01'
AND uc.country_code IN ('US')
AND uc.user_id <> ALL (
SELECT DISTINCT
_id
FROM users, LATERAL IFF(_u.pos = _u_2.pos_2, _u_2.entity, NULL) AS datasource(SEQ, KEY, PATH, INDEX, VALUE, THIS)
WHERE
GET_PATH(datasource.value, 'name') = 'something'
)
THEN 'Sample1'
ELSE 'Sample2'
END AS entity
FROM user_countries AS uc
LEFT JOIN (
SELECT
user_id,
MAX(IFF(service_entity IS NULL, 1, 0)) AS le_null
FROM accepted_user_agreements
GROUP BY
1
) AS aua
ON uc.user_id = aua.user_id
CROSS JOIN TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
GREATEST(ARRAY_SIZE(INPUT => PARSE_JSON(flags))) - 1
) + 1))) AS _u(seq, key, path, index, pos, this)
CROSS JOIN TABLE(FLATTEN(INPUT => PARSE_JSON(flags))) AS _u_2(seq, key, path, pos_2, entity, this)
WHERE
_u.pos = _u_2.pos_2
OR (
_u.pos > (
ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1
)
AND _u_2.pos_2 = (
ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1
)
)""",
},
pretty=True,
)
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(

View file

@@ -18,6 +18,8 @@ class TestStarrocks(Validator):
"DISTRIBUTED BY HASH (col1) PROPERTIES ('replication_num'='1')",
"PRIMARY KEY (col1) DISTRIBUTED BY HASH (col1)",
"DUPLICATE KEY (col1, col2) DISTRIBUTED BY HASH (col1)",
"UNIQUE KEY (col1, col2) PARTITION BY RANGE (col1) (START ('2024-01-01') END ('2024-01-31') EVERY (INTERVAL 1 DAY)) DISTRIBUTED BY HASH (col1)",
"UNIQUE KEY (col1, col2) PARTITION BY RANGE (col1, col2) (START ('1') END ('10') EVERY (1), START ('10') END ('100') EVERY (10)) DISTRIBUTED BY HASH (col1)",
]
for properties in ddl_sqls:
@@ -31,6 +33,9 @@
self.validate_identity(
"CREATE TABLE foo (col0 DECIMAL(9, 1), col1 DECIMAL32(9, 1), col2 DECIMAL64(18, 10), col3 DECIMAL128(38, 10)) DISTRIBUTED BY HASH (col1) BUCKETS 1"
)
self.validate_identity(
"CREATE TABLE foo (col1 LARGEINT) DISTRIBUTED BY HASH (col1) BUCKETS 1"
)
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")

View file

@@ -1579,6 +1579,11 @@ WHERE
},
)
self.validate_identity(
"SELECT DATEADD(DAY, DATEDIFF(DAY, -3, GETDATE()), '08:00:00')",
"SELECT DATEADD(DAY, DATEDIFF(DAY, CAST('1899-12-29' AS DATETIME2), CAST(GETDATE() AS DATETIME2)), '08:00:00')",
)
def test_lateral_subquery(self):
self.validate_all(
"SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)",

View file

@@ -190,6 +190,10 @@ SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0,
SELECT SOME_UDF(data).* FROM t;
SELECT SOME_UDF(t.data).* FROM t AS t;
# execute: false
SELECT p.* FROM p UNION ALL SELECT p2.* FROM p2;
SELECT p.* FROM p AS p UNION ALL SELECT p2.* FROM p2 AS p2;
# execute: false
# allow_partial_qualification: true
# validate_qualify_columns: false
@@ -201,6 +205,30 @@ SELECT x.a + 1 AS i, missing_column AS missing_column FROM x AS x;
SELECT s, arr1, arr2 FROM arrays_test LEFT ARRAY JOIN arr1, arrays_test.arr2;
SELECT arrays_test.s AS s, arrays_test.arr1 AS arr1, arrays_test.arr2 AS arr2 FROM arrays_test AS arrays_test LEFT ARRAY JOIN arrays_test.arr1, arrays_test.arr2;
# execute: false
# dialect: snowflake
WITH employees AS (
SELECT *
FROM (VALUES ('President', 1, NULL),
('Vice President Engineering', 10, 1),
('Programmer', 100, 10),
('QA Engineer', 101, 10),
('Vice President HR', 20, 1),
('Health Insurance Analyst', 200, 20)
) AS t(title, employee_ID, manager_ID)
)
SELECT
employee_ID,
manager_ID,
title,
level
FROM employees
START WITH title = 'President'
CONNECT BY manager_ID = PRIOR employee_id
ORDER BY
employee_ID NULLS LAST;
WITH EMPLOYEES AS (SELECT T.TITLE AS TITLE, T.EMPLOYEE_ID AS EMPLOYEE_ID, T.MANAGER_ID AS MANAGER_ID FROM (VALUES ('President', 1, NULL), ('Vice President Engineering', 10, 1), ('Programmer', 100, 10), ('QA Engineer', 101, 10), ('Vice President HR', 20, 1), ('Health Insurance Analyst', 200, 20)) AS T(TITLE, EMPLOYEE_ID, MANAGER_ID)) SELECT EMPLOYEES.EMPLOYEE_ID AS EMPLOYEE_ID, EMPLOYEES.MANAGER_ID AS MANAGER_ID, EMPLOYEES.TITLE AS TITLE, EMPLOYEES.LEVEL AS LEVEL FROM EMPLOYEES AS EMPLOYEES START WITH EMPLOYEES.TITLE = 'President' CONNECT BY EMPLOYEES.MANAGER_ID = PRIOR EMPLOYEES.EMPLOYEE_ID ORDER BY EMPLOYEE_ID;
--------------------------------------
-- Derived tables
--------------------------------------
@@ -727,3 +755,30 @@ SELECT y.b AS b FROM ((SELECT x.a AS a FROM x AS x) AS _q_0 INNER JOIN y AS y ON
SELECT a, c FROM x TABLESAMPLE SYSTEM (10 ROWS) CROSS JOIN y TABLESAMPLE SYSTEM (10 ROWS);
SELECT x.a AS a, y.c AS c FROM x AS x TABLESAMPLE SYSTEM (10 ROWS) CROSS JOIN y AS y TABLESAMPLE SYSTEM (10 ROWS);
--------------------------------------
-- Snowflake allows column alias to be used in almost all clauses
--------------------------------------
# title: Snowflake column alias in JOIN
# dialect: snowflake
# execute: false
SELECT x.a AS foo FROM x JOIN y ON foo = y.b;
SELECT X.A AS FOO FROM X AS X JOIN Y AS Y ON X.A = Y.B;
# title: Snowflake column alias in QUALIFY
# dialect: snowflake
# execute: false
SELECT x.a AS foo FROM x QUALIFY foo = 1;
SELECT X.A AS FOO FROM X AS X QUALIFY X.A = 1;
# title: Snowflake column alias in GROUP BY
# dialect: snowflake
# execute: false
SELECT x.a AS foo FROM x GROUP BY foo = 1;
SELECT X.A AS FOO FROM X AS X GROUP BY X.A = 1;
# title: Snowflake column alias in WHERE
# dialect: snowflake
# execute: false
SELECT x.a AS foo FROM x WHERE foo = 1;
SELECT X.A AS FOO FROM X AS X WHERE X.A = 1;

View file

@@ -495,3 +495,84 @@ class TestLineage(unittest.TestCase):
self.assertEqual(len(node.downstream), 1)
self.assertEqual(len(node.downstream[0].downstream), 1)
self.assertEqual(node.downstream[0].downstream[0].name, "t1.x")
def test_pivot_without_alias(self) -> None:
sql = """
SELECT
a as other_a
FROM (select value,category from sample_data)
PIVOT (
sum(value)
FOR category IN ('a', 'b')
);
"""
node = lineage("other_a", sql)
self.assertEqual(node.downstream[0].name, "_q_0.value")
self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value")
def test_pivot_with_alias(self) -> None:
sql = """
SELECT
cat_a_s as other_as
FROM sample_data
PIVOT (
sum(value) as s, max(price)
FOR category IN ('a' as cat_a, 'b')
)
"""
node = lineage("other_as", sql)
self.assertEqual(len(node.downstream), 1)
self.assertEqual(node.downstream[0].name, "sample_data.value")
def test_pivot_with_cte(self) -> None:
sql = """
WITH t as (
SELECT
a as other_a
FROM sample_data
PIVOT (
sum(value)
FOR category IN ('a', 'b')
)
)
select other_a from t
"""
node = lineage("other_a", sql)
self.assertEqual(node.downstream[0].name, "t.other_a")
self.assertEqual(node.downstream[0].reference_node_name, "t")
self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value")
def test_pivot_with_implicit_column_of_pivoted_source(self) -> None:
sql = """
SELECT empid
FROM quarterly_sales
PIVOT(SUM(amount) FOR quarter IN (
'2023_Q1',
'2023_Q2',
'2023_Q3'))
ORDER BY empid;
"""
node = lineage("empid", sql)
self.assertEqual(node.downstream[0].name, "quarterly_sales.empid")
def test_pivot_with_implicit_column_of_pivoted_source_and_cte(self) -> None:
sql = """
WITH t as (
SELECT empid
FROM quarterly_sales
PIVOT(SUM(amount) FOR quarter IN (
'2023_Q1',
'2023_Q2',
'2023_Q3'))
)
select empid from t
"""
node = lineage("empid", sql)
self.assertEqual(node.downstream[0].name, "t.empid")
self.assertEqual(node.downstream[0].reference_node_name, "t")
self.assertEqual(node.downstream[0].downstream[0].name, "quarterly_sales.empid")

View file

@@ -551,6 +551,10 @@ class TestOptimizer(unittest.TestCase):
SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True
""".strip(),
)
self.assertEqual(
optimizer.simplify.gen(parse_one("select item_id /* description */"), comments=True),
"SELECT :expressions,item_id /* description */",
)
def test_unnest_subqueries(self):
self.check_file("unnest_subqueries", optimizer.unnest_subqueries.unnest_subqueries)