1
0
Fork 0

Merging upstream version 25.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:39:30 +01:00
parent 7ab180cac9
commit 3b7539dcad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
79 changed files with 28803 additions and 24929 deletions

View file

@ -1,6 +1,16 @@
Changelog
=========
## [v25.0.3] - 2024-06-06
### :sparkles: New Features
- [`97f8d1a`](https://github.com/tobymao/sqlglot/commit/97f8d1a05801bcd7fd237dac0470c232d3106ca4) - add materialize dialect *(PR [#3577](https://github.com/tobymao/sqlglot/pull/3577) by [@bobbyiliev](https://github.com/bobbyiliev))*
- [`bde5a8d`](https://github.com/tobymao/sqlglot/commit/bde5a8de346125704f757ed6a2de444905fe146e) - add risingwave dialect *(PR [#3598](https://github.com/tobymao/sqlglot/pull/3598) by [@neverchanje](https://github.com/neverchanje))*
### :recycle: Refactors
- [`5140817`](https://github.com/tobymao/sqlglot/commit/51408172ce940b6ab0ad783d98e632d972da6a0a) - **risingwave**: clean up initial implementation of RisingWave *(commit by [@georgesittas](https://github.com/georgesittas))*
- [`f920014`](https://github.com/tobymao/sqlglot/commit/f920014709c2d3ccb7ec18fb622ecd6b6ee0afcd) - **materialize**: clean up initial implementation of Materialize *(PR [#3608](https://github.com/tobymao/sqlglot/pull/3608) by [@georgesittas](https://github.com/georgesittas))*
## [v25.0.2] - 2024-06-05
### :sparkles: New Features
- [`472058d`](https://github.com/tobymao/sqlglot/commit/472058daccf8dc2a7f7f4b7082309a06802017a5) - **bigquery**: add support for GAP_FILL function *(commit by [@georgesittas](https://github.com/georgesittas))*
@ -3859,3 +3869,4 @@ Changelog
[v24.1.2]: https://github.com/tobymao/sqlglot/compare/v24.1.1...v24.1.2
[v25.0.0]: https://github.com/tobymao/sqlglot/compare/v24.1.2...v25.0.0
[v25.0.2]: https://github.com/tobymao/sqlglot/compare/v25.0.1...v25.0.2
[v25.0.3]: https://github.com/tobymao/sqlglot/compare/v25.0.2...v25.0.3

View file

@ -86,7 +86,7 @@ I tried to parse invalid SQL and it worked, even though it should raise an error
What happened to sqlglot.dataframe?
* The PySpark dataframe api was moved to a standalone library called [sqlframe](https://github.com/eakmanrq/sqlframe) in v24. It now allows you to run queries as opposed to just generate SQL.
* The PySpark dataframe api was moved to a standalone library called [SQLFrame](https://github.com/eakmanrq/sqlframe) in v24. It now allows you to run queries as opposed to just generate SQL.
## Examples
@ -505,7 +505,7 @@ See also: [Writing a Python SQL engine from scratch](https://github.com/tobymao/
* [Querybook](https://github.com/pinterest/querybook)
* [Quokka](https://github.com/marsupialtail/quokka)
* [Splink](https://github.com/moj-analytical-services/splink)
* [sqlframe](https://github.com/eakmanrq/sqlframe)
* [SQLFrame](https://github.com/eakmanrq/sqlframe)
## Documentation

File diff suppressed because one or more lines are too long

View file

@ -76,8 +76,8 @@
</span><span id="L-12"><a href="#L-12"><span class="linenos">12</span></a><span class="n">__version_tuple__</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span>
</span><span id="L-13"><a href="#L-13"><span class="linenos">13</span></a><span class="n">version_tuple</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span>
</span><span id="L-14"><a href="#L-14"><span class="linenos">14</span></a>
</span><span id="L-15"><a href="#L-15"><span class="linenos">15</span></a><span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;25.0.2&#39;</span>
</span><span id="L-16"><a href="#L-16"><span class="linenos">16</span></a><span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">25</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</span><span id="L-15"><a href="#L-15"><span class="linenos">15</span></a><span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;25.0.3&#39;</span>
</span><span id="L-16"><a href="#L-16"><span class="linenos">16</span></a><span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">25</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
</span></pre></div>
@ -97,7 +97,7 @@
<section id="version">
<div class="attr variable">
<span class="name">version</span><span class="annotation">: str</span> =
<span class="default_value">&#39;25.0.2&#39;</span>
<span class="default_value">&#39;25.0.3&#39;</span>
</div>
@ -109,7 +109,7 @@
<section id="version_tuple">
<div class="attr variable">
<span class="name">version_tuple</span><span class="annotation">: object</span> =
<span class="default_value">(25, 0, 2)</span>
<span class="default_value">(25, 0, 3)</span>
</div>

View file

@ -43,12 +43,14 @@
<li><a href="dialects/drill.html">drill</a></li>
<li><a href="dialects/duckdb.html">duckdb</a></li>
<li><a href="dialects/hive.html">hive</a></li>
<li><a href="dialects/materialize.html">materialize</a></li>
<li><a href="dialects/mysql.html">mysql</a></li>
<li><a href="dialects/oracle.html">oracle</a></li>
<li><a href="dialects/postgres.html">postgres</a></li>
<li><a href="dialects/presto.html">presto</a></li>
<li><a href="dialects/prql.html">prql</a></li>
<li><a href="dialects/redshift.html">redshift</a></li>
<li><a href="dialects/risingwave.html">risingwave</a></li>
<li><a href="dialects/snowflake.html">snowflake</a></li>
<li><a href="dialects/spark.html">spark</a></li>
<li><a href="dialects/spark2.html">spark2</a></li>
@ -212,21 +214,23 @@ dialect implementations in order to understand how their various components can
</span><span id="L-70"><a href="#L-70"><span class="linenos">70</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.drill</span> <span class="kn">import</span> <span class="n">Drill</span>
</span><span id="L-71"><a href="#L-71"><span class="linenos">71</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.duckdb</span> <span class="kn">import</span> <span class="n">DuckDB</span>
</span><span id="L-72"><a href="#L-72"><span class="linenos">72</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.hive</span> <span class="kn">import</span> <span class="n">Hive</span>
</span><span id="L-73"><a href="#L-73"><span class="linenos">73</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.mysql</span> <span class="kn">import</span> <span class="n">MySQL</span>
</span><span id="L-74"><a href="#L-74"><span class="linenos">74</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.oracle</span> <span class="kn">import</span> <span class="n">Oracle</span>
</span><span id="L-75"><a href="#L-75"><span class="linenos">75</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.postgres</span> <span class="kn">import</span> <span class="n">Postgres</span>
</span><span id="L-76"><a href="#L-76"><span class="linenos">76</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.presto</span> <span class="kn">import</span> <span class="n">Presto</span>
</span><span id="L-77"><a href="#L-77"><span class="linenos">77</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.prql</span> <span class="kn">import</span> <span class="n">PRQL</span>
</span><span id="L-78"><a href="#L-78"><span class="linenos">78</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.redshift</span> <span class="kn">import</span> <span class="n">Redshift</span>
</span><span id="L-79"><a href="#L-79"><span class="linenos">79</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.snowflake</span> <span class="kn">import</span> <span class="n">Snowflake</span>
</span><span id="L-80"><a href="#L-80"><span class="linenos">80</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.spark</span> <span class="kn">import</span> <span class="n">Spark</span>
</span><span id="L-81"><a href="#L-81"><span class="linenos">81</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.spark2</span> <span class="kn">import</span> <span class="n">Spark2</span>
</span><span id="L-82"><a href="#L-82"><span class="linenos">82</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.sqlite</span> <span class="kn">import</span> <span class="n">SQLite</span>
</span><span id="L-83"><a href="#L-83"><span class="linenos">83</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.starrocks</span> <span class="kn">import</span> <span class="n">StarRocks</span>
</span><span id="L-84"><a href="#L-84"><span class="linenos">84</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.tableau</span> <span class="kn">import</span> <span class="n">Tableau</span>
</span><span id="L-85"><a href="#L-85"><span class="linenos">85</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.teradata</span> <span class="kn">import</span> <span class="n">Teradata</span>
</span><span id="L-86"><a href="#L-86"><span class="linenos">86</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.trino</span> <span class="kn">import</span> <span class="n">Trino</span>
</span><span id="L-87"><a href="#L-87"><span class="linenos">87</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.tsql</span> <span class="kn">import</span> <span class="n">TSQL</span>
</span><span id="L-73"><a href="#L-73"><span class="linenos">73</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.materialize</span> <span class="kn">import</span> <span class="n">Materialize</span>
</span><span id="L-74"><a href="#L-74"><span class="linenos">74</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.mysql</span> <span class="kn">import</span> <span class="n">MySQL</span>
</span><span id="L-75"><a href="#L-75"><span class="linenos">75</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.oracle</span> <span class="kn">import</span> <span class="n">Oracle</span>
</span><span id="L-76"><a href="#L-76"><span class="linenos">76</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.postgres</span> <span class="kn">import</span> <span class="n">Postgres</span>
</span><span id="L-77"><a href="#L-77"><span class="linenos">77</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.presto</span> <span class="kn">import</span> <span class="n">Presto</span>
</span><span id="L-78"><a href="#L-78"><span class="linenos">78</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.prql</span> <span class="kn">import</span> <span class="n">PRQL</span>
</span><span id="L-79"><a href="#L-79"><span class="linenos">79</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.redshift</span> <span class="kn">import</span> <span class="n">Redshift</span>
</span><span id="L-80"><a href="#L-80"><span class="linenos">80</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.risingwave</span> <span class="kn">import</span> <span class="n">RisingWave</span>
</span><span id="L-81"><a href="#L-81"><span class="linenos">81</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.snowflake</span> <span class="kn">import</span> <span class="n">Snowflake</span>
</span><span id="L-82"><a href="#L-82"><span class="linenos">82</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.spark</span> <span class="kn">import</span> <span class="n">Spark</span>
</span><span id="L-83"><a href="#L-83"><span class="linenos">83</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.spark2</span> <span class="kn">import</span> <span class="n">Spark2</span>
</span><span id="L-84"><a href="#L-84"><span class="linenos">84</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.sqlite</span> <span class="kn">import</span> <span class="n">SQLite</span>
</span><span id="L-85"><a href="#L-85"><span class="linenos">85</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.starrocks</span> <span class="kn">import</span> <span class="n">StarRocks</span>
</span><span id="L-86"><a href="#L-86"><span class="linenos">86</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.tableau</span> <span class="kn">import</span> <span class="n">Tableau</span>
</span><span id="L-87"><a href="#L-87"><span class="linenos">87</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.teradata</span> <span class="kn">import</span> <span class="n">Teradata</span>
</span><span id="L-88"><a href="#L-88"><span class="linenos">88</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.trino</span> <span class="kn">import</span> <span class="n">Trino</span>
</span><span id="L-89"><a href="#L-89"><span class="linenos">89</span></a><span class="kn">from</span> <span class="nn">sqlglot.dialects.tsql</span> <span class="kn">import</span> <span class="n">TSQL</span>
</span></pre></div>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -10073,7 +10073,7 @@ Default: True</li>
<div class="attr variable">
<span class="name">SUPPORTED_JSON_PATH_PARTS</span> =
<input id="Generator.SUPPORTED_JSON_PATH_PARTS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="Generator.SUPPORTED_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;}</span>
<label class="view-value-button pdoc-button" for="Generator.SUPPORTED_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;}</span>
</div>
@ -10355,7 +10355,7 @@ Default: True</li>
<div id="Generator.PARAMETERIZABLE_TEXT_TYPES" class="classattr">
<div class="attr variable">
<span class="name">PARAMETERIZABLE_TEXT_TYPES</span> =
<span class="default_value">{&lt;Type.CHAR: &#39;CHAR&#39;&gt;, &lt;Type.NVARCHAR: &#39;NVARCHAR&#39;&gt;, &lt;Type.VARCHAR: &#39;VARCHAR&#39;&gt;, &lt;Type.NCHAR: &#39;NCHAR&#39;&gt;}</span>
<span class="default_value">{&lt;Type.CHAR: &#39;CHAR&#39;&gt;, &lt;Type.NVARCHAR: &#39;NVARCHAR&#39;&gt;, &lt;Type.NCHAR: &#39;NCHAR&#39;&gt;, &lt;Type.VARCHAR: &#39;VARCHAR&#39;&gt;}</span>
</div>

View file

@ -1893,7 +1893,7 @@ belong to some totally-ordered set.</p>
<section id="DATE_UNITS">
<div class="attr variable">
<span class="name">DATE_UNITS</span> =
<span class="default_value">{&#39;year&#39;, &#39;month&#39;, &#39;quarter&#39;, &#39;day&#39;, &#39;year_month&#39;, &#39;week&#39;}</span>
<span class="default_value">{&#39;quarter&#39;, &#39;month&#39;, &#39;year&#39;, &#39;year_month&#39;, &#39;week&#39;, &#39;day&#39;}</span>
</div>

View file

@ -577,7 +577,7 @@
<div class="attr variable">
<span class="name">ALL_JSON_PATH_PARTS</span> =
<input id="ALL_JSON_PATH_PARTS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="ALL_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;}</span>
<label class="view-value-button pdoc-button" for="ALL_JSON_PATH_PARTS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="expressions.html#JSONPathSubscript">sqlglot.expressions.JSONPathSubscript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSelector">sqlglot.expressions.JSONPathSelector</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathUnion">sqlglot.expressions.JSONPathUnion</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathSlice">sqlglot.expressions.JSONPathSlice</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathScript">sqlglot.expressions.JSONPathScript</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRoot">sqlglot.expressions.JSONPathRoot</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathWildcard">sqlglot.expressions.JSONPathWildcard</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathRecursive">sqlglot.expressions.JSONPathRecursive</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathKey">sqlglot.expressions.JSONPathKey</a>&#39;&gt;, &lt;class &#39;<a href="expressions.html#JSONPathFilter">sqlglot.expressions.JSONPathFilter</a>&#39;&gt;}</span>
</div>

File diff suppressed because one or more lines are too long

View file

@ -586,7 +586,7 @@ queries if it would result in multiple table selects in a single query:</p>
<div class="attr variable">
<span class="name">UNMERGABLE_ARGS</span> =
<input id="UNMERGABLE_ARGS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="UNMERGABLE_ARGS-view-value"></label><span class="default_value">{&#39;sample&#39;, &#39;prewhere&#39;, &#39;offset&#39;, &#39;group&#39;, &#39;with&#39;, &#39;laterals&#39;, &#39;kind&#39;, &#39;distinct&#39;, &#39;having&#39;, &#39;sort&#39;, &#39;cluster&#39;, &#39;limit&#39;, &#39;format&#39;, &#39;locks&#39;, &#39;distribute&#39;, &#39;settings&#39;, &#39;match&#39;, &#39;connect&#39;, &#39;qualify&#39;, &#39;options&#39;, &#39;windows&#39;, &#39;into&#39;, &#39;pivots&#39;}</span>
<label class="view-value-button pdoc-button" for="UNMERGABLE_ARGS-view-value"></label><span class="default_value">{&#39;prewhere&#39;, &#39;locks&#39;, &#39;having&#39;, &#39;distinct&#39;, &#39;into&#39;, &#39;limit&#39;, &#39;match&#39;, &#39;options&#39;, &#39;cluster&#39;, &#39;connect&#39;, &#39;laterals&#39;, &#39;windows&#39;, &#39;qualify&#39;, &#39;offset&#39;, &#39;pivots&#39;, &#39;sort&#39;, &#39;group&#39;, &#39;format&#39;, &#39;with&#39;, &#39;distribute&#39;, &#39;sample&#39;, &#39;kind&#39;, &#39;settings&#39;}</span>
</div>

View file

@ -3220,7 +3220,7 @@ prefix are statically known.</p>
<div class="attr variable">
<span class="name">DATETRUNC_COMPARISONS</span> =
<input id="DATETRUNC_COMPARISONS-view-value" class="view-value-toggle-state" type="checkbox" aria-hidden="true" tabindex="-1">
<label class="view-value-button pdoc-button" for="DATETRUNC_COMPARISONS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="../expressions.html#EQ">sqlglot.expressions.EQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GT">sqlglot.expressions.GT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LT">sqlglot.expressions.LT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#NEQ">sqlglot.expressions.NEQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#In">sqlglot.expressions.In</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GTE">sqlglot.expressions.GTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LTE">sqlglot.expressions.LTE</a>&#39;&gt;}</span>
<label class="view-value-button pdoc-button" for="DATETRUNC_COMPARISONS-view-value"></label><span class="default_value">{&lt;class &#39;<a href="../expressions.html#LT">sqlglot.expressions.LT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#NEQ">sqlglot.expressions.NEQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#EQ">sqlglot.expressions.EQ</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GTE">sqlglot.expressions.GTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#LTE">sqlglot.expressions.LTE</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#GT">sqlglot.expressions.GT</a>&#39;&gt;, &lt;class &#39;<a href="../expressions.html#In">sqlglot.expressions.In</a>&#39;&gt;}</span>
</div>
@ -3300,7 +3300,7 @@ prefix are statically known.</p>
<section id="JOINS">
<div class="attr variable">
<span class="name">JOINS</span> =
<span class="default_value">{(&#39;RIGHT&#39;, &#39;&#39;), (&#39;&#39;, &#39;INNER&#39;), (&#39;RIGHT&#39;, &#39;OUTER&#39;), (&#39;&#39;, &#39;&#39;)}</span>
<span class="default_value">{(&#39;RIGHT&#39;, &#39;&#39;), (&#39;RIGHT&#39;, &#39;OUTER&#39;), (&#39;&#39;, &#39;INNER&#39;), (&#39;&#39;, &#39;&#39;)}</span>
</div>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
date_add_interval_sql,
datestrtodate_sql,
build_formatted_time,
build_timestamp_from_parts,
filter_array_using_unnest,
if_sql,
inline_array_unless_query,
@ -22,6 +23,7 @@ from sqlglot.dialects.dialect import (
build_date_delta_with_interval,
regexp_replace_sql,
rename_func,
sha256_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
unit_to_var,
@ -321,6 +323,7 @@ class BigQuery(Dialect):
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATETIME": build_timestamp_from_parts,
"DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
@ -637,9 +640,7 @@ class BigQuery(Dialect):
]
),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
exp.StabilityProperty: lambda self, e: (
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
@ -649,6 +650,7 @@ class BigQuery(Dialect):
),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
exp.TimeFromParts: rename_func("TIME"),
exp.TimestampFromParts: rename_func("DATETIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),

View file

@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
build_json_extract_path,
rename_func,
sha256_sql,
var_map_sql,
timestamptrunc_sql,
)
@ -758,9 +759,7 @@ class ClickHouse(Dialect):
exp.MD5Digest: rename_func("MD5"),
exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.Variance: rename_func("varSamp"),

View file

@ -169,6 +169,7 @@ class _Dialect(type):
if enum not in ("", "athena", "presto", "trino"):
klass.generator_class.TRY_SUPPORTED = False
klass.generator_class.SUPPORTS_UESCAPE = False
if enum not in ("", "databricks", "hive", "spark", "spark2"):
modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
@ -177,6 +178,14 @@ class _Dialect(type):
klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
if enum not in ("", "doris", "mysql"):
klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
TokenType.STRAIGHT_JOIN,
}
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.STRAIGHT_JOIN,
}
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@ -220,6 +229,9 @@ class Dialect(metaclass=_Dialect):
SUPPORTS_SEMI_ANTI_JOIN = True
"""Whether `SEMI` or `ANTI` joins are supported."""
SUPPORTS_COLUMN_JOIN_MARKS = False
"""Whether the old-style outer join (+) syntax is supported."""
NORMALIZE_FUNCTIONS: bool | str = "upper"
"""
Determines how function names are going to be normalized.
@ -1178,3 +1190,16 @@ def build_default_decimal_type(
return exp.DataType.build(f"DECIMAL({params})")
return _builder
def build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
return exp.TimestampFromParts.from_arg_list(args)
def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
return self.func(f"SHA{expression.text('length') or '256'}", expression.this)

View file

@ -207,7 +207,7 @@ class DuckDB(Dialect):
"PIVOT_WIDER": TokenType.PIVOT,
"POSITIONAL": TokenType.POSITIONAL,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"STRING": TokenType.TEXT,
"UBIGINT": TokenType.UBIGINT,
"UINTEGER": TokenType.UINT,
"USMALLINT": TokenType.USMALLINT,
@ -216,6 +216,7 @@ class DuckDB(Dialect):
"TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
"TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
"TIMESTAMP_US": TokenType.TIMESTAMP,
"VARCHAR": TokenType.TEXT,
}
SINGLE_TOKENS = {
@ -312,9 +313,11 @@ class DuckDB(Dialect):
),
}
TYPE_CONVERTER = {
TYPE_CONVERTERS = {
# https://duckdb.org/docs/sql/data_types/numeric
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3),
# https://duckdb.org/docs/sql/data_types/text
exp.DataType.Type.TEXT: lambda dtype: exp.DataType.build("TEXT"),
}
def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
@ -495,6 +498,7 @@ class DuckDB(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.BPCHAR: "TEXT",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.NCHAR: "TEXT",

View file

@ -202,6 +202,7 @@ class MySQL(Dialect):
"CHARSET": TokenType.CHARACTER_SET,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"KEY": TokenType.KEY,
"LOCK TABLES": TokenType.COMMAND,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,

View file

@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
trim_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import OPTIONS_TYPE
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@ -32,10 +33,171 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)
def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import traverse_scope
"""Remove join marks from an expression
SELECT * FROM a, b WHERE a.id = b.id(+)
becomes:
SELECT * FROM a LEFT JOIN b ON a.id = b.id
- for each scope
- for each column with a join mark
- find the predicate it belongs to
- remove the predicate from the where clause
- convert the predicate to a join with the (+) side as the left join table
- replace the existing join with the new join
Args:
ast: The AST to remove join marks from
Returns:
The AST with join marks removed"""
for scope in traverse_scope(ast):
_eliminate_join_marks_from_scope(scope)
return ast
def _update_from(
select: exp.Select,
new_join_dict: t.Dict[str, exp.Join],
old_join_dict: t.Dict[str, exp.Join],
) -> None:
"""If the from clause needs to become a new join, find an appropriate table to use as the new from.
updates select in place
Args:
select: The select statement to update
new_join_dict: The dictionary of new joins
old_join_dict: The dictionary of old joins
"""
old_from = select.args["from"]
if old_from.alias_or_name not in new_join_dict:
return
in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
if len(in_old_not_new) >= 1:
new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
new_from_this = old_join_dict[new_from_name].this
new_from = exp.From(this=new_from_this)
del old_join_dict[new_from_name]
select.set("from", new_from)
else:
raise ValueError("Cannot determine which table to use as the new from")
def _has_join_mark(col: exp.Expression) -> bool:
"""Check if the column has a join mark
Args:
The column to check
"""
return col.args.get("join_mark", False)
def _predicate_to_join(
eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
) -> t.Optional[exp.Join]:
"""Convert an equality predicate to a join if it contains a join mark
Args:
eq: The equality expression to convert to a join
Returns:
The join expression if the equality contains a join mark (otherwise None)
"""
# if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
# return None
left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
left_has_join_mark = len(left_columns) > 0
right_has_join_mark = len(right_columns) > 0
if left_has_join_mark:
for col in left_columns:
col.set("join_mark", False)
join_on = col.table
elif right_has_join_mark:
for col in right_columns:
col.set("join_mark", False)
join_on = col.table
else:
return None
join_this = old_joins.get(join_on, old_from).this
return exp.Join(this=join_this, on=eq, kind="LEFT")
if t.TYPE_CHECKING:
from sqlglot.optimizer.scope import Scope
def _eliminate_join_marks_from_scope(scope: Scope) -> None:
"""Remove join marks columns in scope's where clause.
Converts them to left joins and replaces any existing joins.
Updates scope in place.
Args:
scope: The scope to remove join marks from
"""
select_scope = scope.expression
where = select_scope.args.get("where")
joins = select_scope.args.get("joins")
if not where:
return
if not joins:
return
# dictionaries used to keep track of joins to be replaced
old_joins = {join.alias_or_name: join for join in list(joins)}
new_joins: t.Dict[str, exp.Join] = {}
for node in scope.find_all(exp.Column):
if _has_join_mark(node):
predicate = node.find_ancestor(exp.Predicate)
if not isinstance(predicate, exp.Binary):
continue
predicate_parent = predicate.parent
join_on = predicate.pop()
new_join = _predicate_to_join(
join_on, old_joins=old_joins, old_from=select_scope.args["from"]
)
# upsert new_join into new_joins dictionary
if new_join:
if new_join.alias_or_name in new_joins:
new_joins[new_join.alias_or_name].set(
"on",
exp.and_(
new_joins[new_join.alias_or_name].args["on"],
new_join.args["on"],
),
)
else:
new_joins[new_join.alias_or_name] = new_join
# If the parent is a binary node with only one child, promote the child to the parent
if predicate_parent:
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
elif predicate_parent.right is None:
predicate_parent.replace(predicate_parent.left)
_update_from(select_scope, new_joins, old_joins)
replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
select_scope.set("joins", replacement_joins)
if not where.this:
where.pop()
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
TABLESAMPLE_SIZE_IS_PERCENT = True
SUPPORTS_COLUMN_JOIN_MARKS = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -70,6 +232,12 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@", "$", "#"}
UNICODE_STRINGS = [
(prefix + q, q)
for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
for prefix in ("U", "u")
]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
@ -132,6 +300,7 @@ class Oracle(Dialect):
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
TokenType.WITH: lambda self: ("options", [self._parse_query_restrictions()]),
}
TYPE_LITERAL_PARSERS = {
@ -144,6 +313,13 @@ class Oracle(Dialect):
# Reference: https://stackoverflow.com/a/336455
DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
QUERY_RESTRICTIONS: OPTIONS_TYPE = {
"WITH": (
("READ", "ONLY"),
("CHECK", "OPTION"),
),
}
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@ -173,12 +349,6 @@ class Oracle(Dialect):
**kwargs,
)
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
start = self._curr
@ -193,11 +363,22 @@ class Oracle(Dialect):
return None
def _parse_query_restrictions(self) -> t.Optional[exp.Expression]:
kind = self._parse_var_from_options(self.QUERY_RESTRICTIONS, raise_unmatched=False)
if not kind:
return None
return self.expression(
exp.QueryOption,
this=kind,
expression=self._match(TokenType.CONSTRAINT) and self._parse_field(),
)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
@ -282,3 +463,10 @@ class Oracle(Dialect):
if len(expression.args.get("actions", [])) > 1:
return f"ADD ({actions})"
return f"ADD {actions}"
def queryoption_sql(self, expression: exp.QueryOption) -> str:
option = self.sql(expression, "this")
value = self.sql(expression, "expression")
value = f" CONSTRAINT {value}" if value else ""
return f"{option}{value}"

View file

@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
Dialect,
JSON_EXTRACT_TYPE,
any_value_to_max_sql,
binary_from_function,
bool_xor_sql,
datestrtodate_sql,
build_formatted_time,
@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
build_json_extract_path,
build_timestamp_trunc,
rename_func,
sha256_sql,
str_position_sql,
struct_extract_sql,
timestamptrunc_sql,
@ -329,6 +331,7 @@ class Postgres(Dialect):
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
"FLOAT": TokenType.DOUBLE,
}
KEYWORDS.pop("DIV")
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@ -347,6 +350,9 @@ class Postgres(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": build_timestamp_trunc,
"DIV": lambda args: exp.cast(
binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
),
"GENERATE_SERIES": _build_generate_series,
"JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
"JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
@ -357,6 +363,9 @@ class Postgres(Dialect):
"TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"SHA384": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(384)),
"SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
}
FUNCTION_PARSERS = {
@ -494,6 +503,7 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@ -528,6 +538,7 @@ class Postgres(Dialect):
transforms.eliminate_qualify,
]
),
exp.SHA2: sha256_sql,
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
@ -621,3 +632,12 @@ class Postgres(Dialect):
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return super().datatype_sql(expression)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
this = expression.this
# Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"):
return self.sql(this)
return super().cast_sql(expression, safe_prefix=safe_prefix)

View file

@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
regexp_extract_sql,
rename_func,
right_to_substring_sql,
sha256_sql,
struct_extract_sql,
str_position_sql,
timestamptrunc_sql,
@ -452,9 +453,7 @@ class Presto(Dialect):
),
exp.MD5Digest: rename_func("MD5"),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.SHA2: sha256_sql,
}
RESERVED_KEYWORDS = {

View file

@ -40,6 +40,7 @@ class Redshift(Postgres):
INDEX_OFFSET = 0
COPY_PARAMS_ARE_CSV = False
HEX_LOWERCASE = True
SUPPORTS_COLUMN_JOIN_MARKS = True
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
@ -122,12 +123,13 @@ class Redshift(Postgres):
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
"HLLSKETCH": TokenType.HLLSKETCH,
"MINUS": TokenType.EXCEPT,
"SUPER": TokenType.SUPER,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"MINUS": TokenType.EXCEPT,
}
KEYWORDS.pop("VALUES")
@ -209,6 +211,7 @@ class Redshift(Postgres):
# Redshift supports LAST_DAY(..)
TRANSFORMS.pop(exp.LastDay)
TRANSFORMS.pop(exp.SHA2)
RESERVED_KEYWORDS = {
"aes128",

View file

@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
binary_from_function,
build_default_decimal_type,
build_timestamp_from_parts,
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
@ -236,15 +237,6 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
return trunc
def _build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
return exp.TimestampFromParts.from_arg_list(args)
def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
"""
Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
@ -391,8 +383,8 @@ class Snowflake(Dialect):
"TIMEDIFF": _build_datediff,
"TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
"TIMESTAMPFROMPARTS": build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
@ -446,7 +438,7 @@ class Snowflake(Dialect):
"LOCATION": lambda self: self._parse_location_property(),
}
TYPE_CONVERTER = {
TYPE_CONVERTERS = {
# https://docs.snowflake.com/en/sql-reference/data-types-numeric#number
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0),
}
@ -510,15 +502,18 @@ class Snowflake(Dialect):
self._retreat(self._index - 1)
if self._match_text_seq("MASKING", "POLICY"):
policy = self._parse_column()
return self.expression(
exp.MaskingPolicyColumnConstraint,
this=self._parse_id_var(),
this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
expressions=self._match(TokenType.USING)
and self._parse_wrapped_csv(self._parse_id_var),
)
if self._match_text_seq("PROJECTION", "POLICY"):
policy = self._parse_column()
return self.expression(
exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
exp.ProjectionPolicyColumnConstraint,
this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
)
if self._match(TokenType.TAG):
return self.expression(

View file

@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
)
def _build_dateadd(args: t.List) -> exp.Expression:
expression = seq_get(args, 1)
if len(args) == 2:
# DATE_ADD(startDate, numDays INTEGER)
# https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
return exp.TsOrDsAdd(
this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
)
# DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
# https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
return e
def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
if not expression.unit or (
isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
):
# Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
return self.func("DATE_ADD", expression.this, expression.expression)
this = self.func(
"DATE_ADD",
unit_to_var(expression),
expression.expression,
expression.this,
)
if isinstance(expression, exp.TsOrDsAdd):
# The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
# in other dialects
return_type = expression.return_type
if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
this = f"CAST({this} AS {return_type})"
return this
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
@ -62,6 +101,9 @@ class Spark(Spark2):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATE_ADD": _build_dateadd,
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@ -111,9 +153,8 @@ class Spark(Spark2):
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TsOrDsAdd: _dateadd_sql,
exp.TimestampAdd: _dateadd_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),

View file

@ -75,6 +75,26 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
return expression
def _generated_to_auto_increment(expression: exp.Expression) -> exp.Expression:
if not isinstance(expression, exp.ColumnDef):
return expression
generated = expression.find(exp.GeneratedAsIdentityColumnConstraint)
if generated:
t.cast(exp.ColumnConstraint, generated.parent).pop()
not_null = expression.find(exp.NotNullColumnConstraint)
if not_null:
t.cast(exp.ColumnConstraint, not_null.parent).pop()
expression.append(
"constraints", exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint())
)
return expression
class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@ -141,6 +161,7 @@ class SQLite(Dialect):
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ColumnDef: transforms.preprocess([_generated_to_auto_increment]),
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.If: rename_func("IIF"),

View file

@ -1118,3 +1118,7 @@ class TSQL(Dialect):
kind = f"TABLE {kind}"
return f"{variable} AS {kind}{default}"
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" OPTION{self.wrap(options)}" if options else ""

View file

@ -3119,22 +3119,6 @@ class Intersect(Union):
pass
class Unnest(UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Update(Expression):
arg_types = {
"with": False,
@ -5240,6 +5224,22 @@ class PosexplodeOuter(Posexplode, ExplodeOuter):
pass
class Unnest(Func, UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@ -5765,7 +5765,7 @@ class StrPosition(Func):
class StrToDate(Func):
arg_types = {"this": True, "format": True}
arg_types = {"this": True, "format": False}
class StrToTime(Func):

View file

@ -225,9 +225,6 @@ class Generator(metaclass=_Generator):
# Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True
# Whether to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
@ -359,6 +356,9 @@ class Generator(metaclass=_Generator):
# Whether the conditional TRY(expression) function is supported
TRY_SUPPORTED = True
# Whether the UESCAPE syntax in unicode strings is supported
SUPPORTS_UESCAPE = True
# The keyword to use when generating a star projection with excluded columns
STAR_EXCEPT = "EXCEPT"
@ -827,7 +827,7 @@ class Generator(metaclass=_Generator):
def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS:
join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.")
@ -1146,16 +1146,23 @@ class Generator(metaclass=_Generator):
escape = expression.args.get("escape")
if self.dialect.UNICODE_START:
escape = f" UESCAPE {self.sql(escape)}" if escape else ""
return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
escape_substitute = r"\\\1"
left_quote, right_quote = self.dialect.UNICODE_START, self.dialect.UNICODE_END
else:
escape_substitute = r"\\u\1"
left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END
if escape:
pattern = re.compile(rf"{escape.name}(\d+)")
escape_pattern = re.compile(rf"{escape.name}(\d+)")
escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else ""
else:
pattern = ESCAPED_UNICODE_RE
escape_pattern = ESCAPED_UNICODE_RE
escape_sql = ""
this = pattern.sub(r"\\u\1", this)
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE):
this = escape_pattern.sub(escape_substitute, this)
return f"{left_quote}{this}{right_quote}{escape_sql}"
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
@ -1973,7 +1980,9 @@ class Generator(metaclass=_Generator):
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
if op_sql != "STRAIGHT_JOIN":
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
@ -2235,10 +2244,6 @@ class Generator(metaclass=_Generator):
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
options = self.expressions(expression, key="options")
if options:
options = f" OPTION{self.wrap(options)}"
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
@ -2253,10 +2258,14 @@ class Generator(metaclass=_Generator):
self.sql(expression, "order"),
*self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
*self.after_limit_modifiers(expression),
options,
self.options_modifier(expression),
sep="",
)
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" {options}" if options else ""
def queryoption_sql(self, expression: exp.QueryOption) -> str:
return ""

View file

@ -1034,7 +1034,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return (
DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
trunc_arg, date, unit, dialect, extract_type(r)
)
or expression
)
@ -1060,7 +1060,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression
ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs)
target_type = extract_type(*rs)
return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False

View file

@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
}
JOIN_KINDS = {
TokenType.ANTI,
TokenType.CROSS,
TokenType.INNER,
TokenType.OUTER,
TokenType.CROSS,
TokenType.SEMI,
TokenType.ANTI,
TokenType.STRAIGHT_JOIN,
}
JOIN_HINTS: t.Set[str] = set()
@ -1065,7 +1066,7 @@ class Parser(metaclass=_Parser):
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
}
TYPE_CONVERTER: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
TYPE_CONVERTERS: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
@ -1138,7 +1139,14 @@ class Parser(metaclass=_Parser):
FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
ADD_CONSTRAINT_TOKENS = {
TokenType.CONSTRAINT,
TokenType.FOREIGN_KEY,
TokenType.INDEX,
TokenType.KEY,
TokenType.PRIMARY_KEY,
TokenType.UNIQUE,
}
DISTINCT_TOKENS = {TokenType.DISTINCT}
@ -3099,7 +3107,7 @@ class Parser(metaclass=_Parser):
index = self._index
method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)
if not skip_join_token and not join:
self._retreat(index)
@ -3242,7 +3250,7 @@ class Parser(metaclass=_Parser):
while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
hint = exp.IndexTableHint(this=self._prev.text.upper())
self._match_texts(("INDEX", "KEY"))
self._match_set((TokenType.INDEX, TokenType.KEY))
if self._match(TokenType.FOR):
hint.set("target", self._advance_any() and self._prev.text.upper())
@ -4464,8 +4472,8 @@ class Parser(metaclass=_Parser):
)
self._match(TokenType.R_BRACKET)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this)
if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTERS.get(this.this)
if converter:
this = converter(t.cast(exp.DataType, this))
@ -4496,7 +4504,12 @@ class Parser(metaclass=_Parser):
def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_column_reference()
return self._parse_column_ops(this) if this else self._parse_bracket(this)
column = self._parse_column_ops(this) if this else self._parse_bracket(this)
if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column:
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
@ -4522,7 +4535,11 @@ class Parser(metaclass=_Parser):
while self._match(TokenType.COLON):
start_index = self._index
path = self._parse_column_ops(self._parse_field(any_token=True))
# Snowflake allows reserved keywords as json keys but advance_any() excludes TokenType.SELECT from any_tokens=True
path = self._parse_column_ops(
self._parse_field(any_token=True, tokens=(TokenType.SELECT,))
)
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the JSON path

View file

@ -287,6 +287,7 @@ class TokenType(AutoName):
JOIN = auto()
JOIN_MARKER = auto()
KEEP = auto()
KEY = auto()
KILL = auto()
LANGUAGE = auto()
LATERAL = auto()
@ -360,6 +361,7 @@ class TokenType(AutoName):
SORT_BY = auto()
START_WITH = auto()
STORAGE_INTEGRATION = auto()
STRAIGHT_JOIN = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TAG = auto()
@ -764,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH,
"STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
@ -1270,18 +1273,6 @@ class Tokenizer(metaclass=_Tokenizer):
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
if (
self.HEREDOC_TAG_IS_IDENTIFIER
and not self._peek.isidentifier()
and not self._peek == end
):
if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR:
self._add(self.HEREDOC_STRING_ALTERNATIVE)
else:
self._scan_var()
return True
self._advance()
if self._char == end:
@ -1293,7 +1284,10 @@ class Tokenizer(metaclass=_Tokenizer):
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
)
if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER:
if tag and self.HEREDOC_TAG_IS_IDENTIFIER and (self._end or not tag.isidentifier()):
if not self._end:
self._advance(-1)
self._advance(-len(tag))
self._add(self.HEREDOC_STRING_ALTERNATIVE)
return True

View file

@ -505,7 +505,10 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
def _ensure_bool(node: exp.Expression) -> None:
if (
node.is_number
or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
or (
not isinstance(node, exp.SubqueryPredicate)
and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
)
or (isinstance(node, exp.Column) and not node.type)
):
node.replace(node.neq(0))

2
sqlglotrs/Cargo.lock generated
View file

@ -188,7 +188,7 @@ checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]]
name = "sqlglotrs"
version = "0.2.5"
version = "0.2.6"
dependencies = [
"pyo3",
]

View file

@ -1,6 +1,6 @@
[package]
name = "sqlglotrs"
version = "0.2.5"
version = "0.2.6"
edition = "2021"
[lib]

View file

@ -405,19 +405,6 @@ impl<'a> TokenizerState<'a> {
} else if *token_type == self.token_types.bit_string {
(Some(2), *token_type, end.clone())
} else if *token_type == self.token_types.heredoc_string {
if self.settings.heredoc_tag_is_identifier
&& !self.is_identifier(self.peek_char)
&& self.peek_char.to_string() != *end
{
if self.token_types.heredoc_string_alternative != self.token_types.var {
self.add(self.token_types.heredoc_string_alternative, None)?
} else {
self.scan_var()?
};
return Ok(true)
};
self.advance(1)?;
let tag = if self.current_char.to_string() == *end {
@ -426,7 +413,14 @@ impl<'a> TokenizerState<'a> {
self.extract_string(end, false, false, !self.settings.heredoc_tag_is_identifier)?
};
if self.is_end && !tag.is_empty() && self.settings.heredoc_tag_is_identifier {
if !tag.is_empty()
&& self.settings.heredoc_tag_is_identifier
&& (self.is_end || !self.is_identifier(&tag))
{
if !self.is_end {
self.advance(-1)?;
}
self.advance(-(tag.len() as isize))?;
self.add(self.token_types.heredoc_string_alternative, None)?;
return Ok(true)
@ -494,7 +488,7 @@ impl<'a> TokenizerState<'a> {
} else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 {
scientific += 1;
self.advance(1)?;
} else if self.is_identifier(self.peek_char) {
} else if self.is_alphabetic_or_underscore(self.peek_char) {
let number_text = self.text();
let mut literal = String::from("");
@ -676,10 +670,18 @@ impl<'a> TokenizerState<'a> {
Ok(text)
}
fn is_identifier(&mut self, name: char) -> bool {
fn is_alphabetic_or_underscore(&mut self, name: char) -> bool {
name.is_alphabetic() || name == '_'
}
fn is_identifier(&mut self, s: &str) -> bool {
s.chars().enumerate().all(
|(i, c)|
if i == 0 { self.is_alphabetic_or_underscore(c) }
else { self.is_alphabetic_or_underscore(c) || c.is_digit(10) }
)
}
fn extract_value(&mut self) -> Result<String, TokenizerError> {
loop {
if !self.peek_char.is_whitespace()

View file

@ -20,6 +20,14 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
self.validate_all(
"EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
write={
"bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
"duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
"snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
},
)
self.validate_identity(
"""CREATE TEMPORARY FUNCTION FOO()
RETURNS STRING
@ -619,9 +627,9 @@ LANGUAGE js AS
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
"databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
"spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
@ -761,12 +769,15 @@ LANGUAGE js AS
"clickhouse": "SHA256(x)",
"presto": "SHA256(x)",
"trino": "SHA256(x)",
"postgres": "SHA256(x)",
},
write={
"bigquery": "SHA256(x)",
"spark2": "SHA2(x, 256)",
"clickhouse": "SHA256(x)",
"postgres": "SHA256(x)",
"presto": "SHA256(x)",
"redshift": "SHA2(x, 256)",
"trino": "SHA256(x)",
},
)

View file

@ -18,6 +18,13 @@ class TestDuckDB(Validator):
"WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1",
)
self.validate_all(
"SELECT straight_join",
write={
"duckdb": "SELECT straight_join",
"mysql": "SELECT `straight_join`",
},
)
self.validate_all(
"SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
read={
@ -278,6 +285,7 @@ class TestDuckDB(Validator):
self.validate_identity("FROM tbl", "SELECT * FROM tbl")
self.validate_identity("x -> '$.family'")
self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob")
self.validate_identity(
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
)
@ -1000,6 +1008,7 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)")
self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)")
self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)")
@ -1027,6 +1036,13 @@ class TestDuckDB(Validator):
"CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
)
self.validate_all(
"CAST(x AS VARCHAR(5))",
write={
"duckdb": "CAST(x AS TEXT)",
"postgres": "CAST(x AS TEXT)",
},
)
self.validate_all(
"CAST(x AS DECIMAL(38, 0))",
read={

View file

@ -21,6 +21,9 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)")
self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)")
self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT")
self.validate_identity(
"CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
)
@ -60,6 +63,10 @@ class TestMySQL(Validator):
self.validate_identity(
"CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
)
self.validate_identity(
"ALTER TABLE t ADD KEY `i` (`c`)",
"ALTER TABLE t ADD INDEX `i` (`c`)",
)
self.validate_identity(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
"CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
@ -76,9 +83,6 @@ class TestMySQL(Validator):
"ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT",
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity(
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
@ -113,6 +117,7 @@ class TestMySQL(Validator):
)
def test_identity(self):
self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2")

View file

@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.errors import UnsupportedError
from sqlglot import exp, UnsupportedError
from sqlglot.dialects.oracle import eliminate_join_marks
from tests.dialects.test_dialect import Validator
@ -43,6 +43,7 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')")
self.validate_identity(
"SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name"
)
@ -249,7 +250,8 @@ class TestOracle(Validator):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError}
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
write={"": UnsupportedError},
)
self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
@ -413,3 +415,65 @@ WHERE
for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True)
def test_eliminate_join_marks(self):
    """Check that Oracle-style (+) outer-join markers are rewritten into explicit LEFT JOINs.

    Each pair is (original SQL using (+) markers, expected SQL after the
    eliminate_join_marks transform). Predicates without a marker stay in the
    WHERE clause; marked predicates move into the join's ON condition.
    """
    test_sql = [
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
        ),
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
        ),
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
        ),
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
        ),
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
        ),
        (
            "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
        ),
        (
            "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
        ),
        # subquery in the FROM clause: markers are eliminated both inside the
        # subquery and in the outer query
        (
            "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
            "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
        ),
        # 2 join marks on one side of predicate
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
        ),
        # join mark and expression
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
        ),
    ]
    for original, expected in test_sql:
        # subTest keyed by the input SQL so a single failing case is identifiable
        with self.subTest(original):
            self.assertEqual(
                eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
                expected,
            )
def test_query_restrictions(self):
    """Round-trip Oracle query restriction clauses.

    Covers WITH READ ONLY and WITH CHECK OPTION, each with and without a
    trailing CONSTRAINT name, on both plain SELECTs and CREATE VIEW statements.
    """
    for restriction in ("READ ONLY", "CHECK OPTION"):
        for constraint_name in (" CONSTRAINT name", ""):
            with self.subTest(f"Restriction: {restriction}"):
                self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}")
                self.validate_identity(
                    f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
                )

View file

@ -8,6 +8,7 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_postgres(self):
self.validate_identity("SHA384(x)")
self.validate_identity(
'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
)
@ -724,6 +725,28 @@ class TestPostgres(Validator):
self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")
self.validate_all(
"1 / DIV(4, 2)",
read={
"postgres": "1 / DIV(4, 2)",
},
write={
"sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
"duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
"bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
},
)
self.validate_all(
"CAST(DIV(4, 2) AS DECIMAL(5, 3))",
read={
"duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
},
write={
"duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
"postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
},
)
def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(

View file

@ -564,6 +564,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter \\2603 !'",
write={
"oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter \\2603 !'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",
@ -572,6 +573,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
write={
"oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter #2603 !' UESCAPE '#'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",

View file

@ -281,6 +281,9 @@ class TestRedshift(Validator):
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
"snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
"spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
"spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
"databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
},
)
self.validate_all(
@ -585,3 +588,9 @@ FROM (
self.assertEqual(
ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
)
def test_join_markers(self):
    """Oracle-style (+) join markers are preserved by the Redshift dialect;
    only keyword casing is normalized in the output."""
    self.validate_identity(
        "select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)",
        "SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)",
    )

View file

@ -125,6 +125,10 @@ WHERE
"SELECT a:from::STRING, a:from || ' test' ",
"SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
)
self.validate_identity(
"SELECT a:select",
"SELECT GET_PATH(a, 'select')",
)
self.validate_identity("x:from", "GET_PATH(x, 'from')")
self.validate_identity(
"value:values::string::int",
@ -1196,16 +1200,16 @@ WHERE
for constraint_prefix in ("WITH ", ""):
with self.subTest(f"Constraint prefix: {constraint_prefix}"):
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p)",
"CREATE TABLE t (id INT MASKING POLICY p)",
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)",
"CREATE TABLE t (id INT MASKING POLICY p.q.r)",
)
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))",
"CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))",
)
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p)",
"CREATE TABLE t (id INT PROJECTION POLICY p)",
f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)",
"CREATE TABLE t (id INT PROJECTION POLICY p.q.r)",
)
self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))",

View file

@ -563,6 +563,7 @@ TBLPROPERTIES (
"SELECT DATE_ADD(my_date_column, 1)",
write={
"spark": "SELECT DATE_ADD(my_date_column, 1)",
"spark2": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
},
)
@ -675,6 +676,16 @@ TBLPROPERTIES (
"spark": "SELECT ARRAY_SORT(x)",
},
)
self.validate_all(
"SELECT DATE_ADD(MONTH, 20, col)",
read={
"spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
},
write={
"spark": "SELECT DATE_ADD(MONTH, 20, col)",
"databricks": "SELECT DATE_ADD(MONTH, 20, col)",
},
)
def test_bool_or(self):
self.validate_all(

View file

@ -202,6 +202,7 @@ class TestSQLite(Validator):
"CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
read={
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
"postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
},
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",

View file

@ -1,12 +1,18 @@
from sqlglot import exp, parse, parse_one
from sqlglot import exp, parse
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
from sqlglot.optimizer.annotate_types import annotate_types
class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
self.assertEqual(
annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
"SELECT 1 WHERE EXISTS(SELECT 1)",
)
self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
@ -217,9 +223,9 @@ class TestTSQL(Validator):
"CREATE TABLE [db].[tbl] ([a] INTEGER)",
)
projection = parse_one("SELECT a = 1", read="tsql").selects[0]
projection.assert_is(exp.Alias)
projection.args["alias"].assert_is(exp.Identifier)
self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
exp.Alias
).args["alias"].assert_is(exp.Identifier)
self.validate_all(
"IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@ -756,12 +762,9 @@ class TestTSQL(Validator):
for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")
expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql")
self.assertIsInstance(expression, exp.AlterTable)
self.assertIsInstance(expression.args["actions"][0], exp.Drop)
self.assertEqual(
expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
)
self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
exp.AlterTable
).args["actions"][0].assert_is(exp.Drop)
for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity(
@ -795,10 +798,10 @@ class TestTSQL(Validator):
)
self.validate_all(
"CREATE TABLE [#temptest] (name VARCHAR)",
"CREATE TABLE [#temptest] (name INTEGER)",
read={
"duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)",
"tsql": "CREATE TABLE [#temptest] (name VARCHAR)",
"duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
"tsql": "CREATE TABLE [#temptest] (name INTEGER)",
},
)
self.validate_all(
@ -1632,27 +1635,23 @@ WHERE
)
def test_identifier_prefixes(self):
expr = parse_one("#x", read="tsql")
self.assertIsInstance(expr, exp.Column)
self.assertIsInstance(expr.this, exp.Identifier)
self.assertTrue(expr.this.args.get("temporary"))
self.assertEqual(expr.sql("tsql"), "#x")
self.assertTrue(
self.validate_identity("#x")
.assert_is(exp.Column)
.this.assert_is(exp.Identifier)
.args.get("temporary")
)
self.assertTrue(
self.validate_identity("##x")
.assert_is(exp.Column)
.this.assert_is(exp.Identifier)
.args.get("global")
)
expr = parse_one("##x", read="tsql")
self.assertIsInstance(expr, exp.Column)
self.assertIsInstance(expr.this, exp.Identifier)
self.assertTrue(expr.this.args.get("global"))
self.assertEqual(expr.sql("tsql"), "##x")
expr = parse_one("@x", read="tsql")
self.assertIsInstance(expr, exp.Parameter)
self.assertIsInstance(expr.this, exp.Var)
self.assertEqual(expr.sql("tsql"), "@x")
table = parse_one("select * from @x", read="tsql").args["from"].this
self.assertIsInstance(table, exp.Table)
self.assertIsInstance(table.this, exp.Parameter)
self.assertIsInstance(table.this.this, exp.Var)
self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
exp.Table
).this.assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_all(
"SELECT @x",
@ -1663,8 +1662,6 @@ WHERE
"tsql": "SELECT @x",
},
)
def test_temp_table(self):
self.validate_all(
"SELECT * FROM #mytemptable",
write={

View file

@ -872,3 +872,4 @@ SELECT name
SELECT copy
SELECT rollup
SELECT unnest
SELECT * FROM a STRAIGHT_JOIN b

View file

@ -1047,6 +1047,9 @@ x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE);
TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME);
x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);
DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP);
CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP);
--------------------------------------
-- EQUALITY
--------------------------------------

View file

@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):
def qualify_columns(expression, **kwargs):
expression = optimizer.qualify.qualify(
expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs
expression,
infer_schema=True,
validate_qualify_columns=False,
identify=False,
**kwargs,
)
return expression
@ -111,7 +115,14 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(
self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs
self,
file,
func,
pretty=False,
execute=False,
set_dialect=False,
only=None,
**kwargs,
):
with ProcessPoolExecutor() as pool:
results = {}
@ -331,7 +342,11 @@ class TestOptimizer(unittest.TestCase):
)
self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
"qualify_columns",
qualify_columns,
execute=True,
schema=self.schema,
set_dialect=True,
)
self.check_file(
"qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
@ -343,7 +358,8 @@ class TestOptimizer(unittest.TestCase):
def test_pushdown_cte_alias_columns(self):
self.check_file(
"pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns
"pushdown_cte_alias_columns",
optimizer.qualify_columns.pushdown_cte_alias_columns,
)
def test_qualify_columns__invalid(self):
@ -405,7 +421,8 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))
anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
this=exp.to_identifier("anonymous"),
expressions=[exp.column("x"), exp.column("y")],
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")
@ -416,7 +433,10 @@ class TestOptimizer(unittest.TestCase):
anon_invalid = exp.Anonymous(this=5)
optimizer.simplify.gen(anon_invalid)
self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
self.assertIn(
"Anonymous.this expects a str or an Identifier, got 'int'.",
str(e.exception),
)
sql = parse_one(
"""
@ -906,7 +926,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
cte_select.find_all(exp.Subquery),
[exp.DataType.Type.CHAR, exp.DataType.Type.TEXT],
):
self.assertEqual(d.this.expressions[0].this.type.this, t)
@ -1020,7 +1041,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for (func, col), target_type in tests.items():
expression = annotate_types(
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"),
schema=schema,
)
self.assertEqual(expression.expressions[0].type.this, target_type)
@ -1035,7 +1057,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)
def test_nested_type_annotation(self):
schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}}
schema = {
"order": {
"customer_id": "bigint",
"item_id": "bigint",
"item_price": "numeric",
}
}
sql = """
SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
FROM order AS order
@ -1057,7 +1085,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
self.assertEqual(
expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>"
expression.selects[1].type.sql(dialect="bigquery"),
"ARRAY<STRUCT<`f` STRING>>",
)
expression = annotate_types(
@ -1206,7 +1235,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(
optimizer.optimize(
parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery")
parse_one("SELECT * FROM a"),
schema=MappingSchema(schema, dialect="bigquery"),
),
parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
)

View file

@ -106,6 +106,7 @@ class TestParser(unittest.TestCase):
expr = parse_one("SELECT foo IN UNNEST(bla) AS bar")
self.assertIsInstance(expr.selects[0], exp.Alias)
self.assertEqual(expr.selects[0].output_name, "bar")
self.assertIsNotNone(parse_one("select unnest(x)").find(exp.Unnest))
def test_unary_plus(self):
self.assertEqual(parse_one("+15"), exp.Literal.number(15))
@ -880,10 +881,12 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or)
def test_trailing_comments(self):
expressions = parse("""
expressions = parse(
"""
select * from x;
-- my comment
""")
"""
)
self.assertEqual(
";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */"