
Merging upstream version 25.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:39:30 +01:00
parent 7ab180cac9
commit 3b7539dcad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
79 changed files with 28803 additions and 24929 deletions


@@ -1,6 +1,16 @@
 Changelog
 =========
+## [v25.0.3] - 2024-06-06
+### :sparkles: New Features
+- [`97f8d1a`](https://github.com/tobymao/sqlglot/commit/97f8d1a05801bcd7fd237dac0470c232d3106ca4) - add materialize dialect *(PR [#3577](https://github.com/tobymao/sqlglot/pull/3577) by [@bobbyiliev](https://github.com/bobbyiliev))*
+- [`bde5a8d`](https://github.com/tobymao/sqlglot/commit/bde5a8de346125704f757ed6a2de444905fe146e) - add risingwave dialect *(PR [#3598](https://github.com/tobymao/sqlglot/pull/3598) by [@neverchanje](https://github.com/neverchanje))*
+### :recycle: Refactors
+- [`5140817`](https://github.com/tobymao/sqlglot/commit/51408172ce940b6ab0ad783d98e632d972da6a0a) - **risingwave**: clean up initial implementation of RisingWave *(commit by [@georgesittas](https://github.com/georgesittas))*
+- [`f920014`](https://github.com/tobymao/sqlglot/commit/f920014709c2d3ccb7ec18fb622ecd6b6ee0afcd) - **materialize**: clean up initial implementation of Materialize *(PR [#3608](https://github.com/tobymao/sqlglot/pull/3608) by [@georgesittas](https://github.com/georgesittas))*
 ## [v25.0.2] - 2024-06-05
 ### :sparkles: New Features
 - [`472058d`](https://github.com/tobymao/sqlglot/commit/472058daccf8dc2a7f7f4b7082309a06802017a5) - **bigquery**: add support for GAP_FILL function *(commit by [@georgesittas](https://github.com/georgesittas))*
@@ -3859,3 +3869,4 @@ Changelog
 [v24.1.2]: https://github.com/tobymao/sqlglot/compare/v24.1.1...v24.1.2
 [v25.0.0]: https://github.com/tobymao/sqlglot/compare/v24.1.2...v25.0.0
 [v25.0.2]: https://github.com/tobymao/sqlglot/compare/v25.0.1...v25.0.2
+[v25.0.3]: https://github.com/tobymao/sqlglot/compare/v25.0.2...v25.0.3


@@ -86,7 +86,7 @@ I tried to parse invalid SQL and it worked, even though it should raise an error
 What happened to sqlglot.dataframe?
-* The PySpark dataframe api was moved to a standalone library called [sqlframe](https://github.com/eakmanrq/sqlframe) in v24. It now allows you to run queries as opposed to just generate SQL.
+* The PySpark dataframe api was moved to a standalone library called [SQLFrame](https://github.com/eakmanrq/sqlframe) in v24. It now allows you to run queries as opposed to just generate SQL.
 ## Examples
@@ -505,7 +505,7 @@ See also: [Writing a Python SQL engine from scratch](https://github.com/tobymao/
 * [Querybook](https://github.com/pinterest/querybook)
 * [Quokka](https://github.com/marsupialtail/quokka)
 * [Splink](https://github.com/moj-analytical-services/splink)
-* [sqlframe](https://github.com/eakmanrq/sqlframe)
+* [SQLFrame](https://github.com/eakmanrq/sqlframe)
 ## Documentation

File diff suppressed because one or more lines are too long


@@ -76,8 +76,8 @@
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '25.0.2'
+__version__ = version = '25.0.3'
-__version_tuple__ = version_tuple = (25, 0, 2)
+__version_tuple__ = version_tuple = (25, 0, 3)
@@ -97,7 +97,7 @@
 version: str =
-    '25.0.2'
+    '25.0.3'
@@ -109,7 +109,7 @@
 version_tuple: object =
-    (25, 0, 2)
+    (25, 0, 3)


@@ -43,12 +43,14 @@
 drill
 duckdb
 hive
+materialize
 mysql
 oracle
 postgres
 presto
 prql
 redshift
+risingwave
 snowflake
 spark
 spark2
@@ -212,21 +214,23 @@ dialect implementations in order to understand how their various components can
 from sqlglot.dialects.drill import Drill
 from sqlglot.dialects.duckdb import DuckDB
 from sqlglot.dialects.hive import Hive
+from sqlglot.dialects.materialize import Materialize
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
 from sqlglot.dialects.prql import PRQL
 from sqlglot.dialects.redshift import Redshift
+from sqlglot.dialects.risingwave import RisingWave
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.spark2 import Spark2
 from sqlglot.dialects.sqlite import SQLite
 from sqlglot.dialects.starrocks import StarRocks
 from sqlglot.dialects.tableau import Tableau
 from sqlglot.dialects.teradata import Teradata
 from sqlglot.dialects.trino import Trino
 from sqlglot.dialects.tsql import TSQL

28 file diffs suppressed because one or more lines are too long or the file is too large


@@ -10073,7 +10073,7 @@ Default: True
 Generator.SUPPORTED_JSON_PATH_PARTS = {JSONPathSlice, JSONPathScript, JSONPathRoot, JSONPathRecursive, JSONPathKey, JSONPathWildcard, JSONPathFilter, JSONPathUnion, JSONPathSubscript, JSONPathSelector}
 (regenerated docs: same members, different set iteration order)
@@ -10355,7 +10355,7 @@ Default: True
 Generator.PARAMETERIZABLE_TEXT_TYPES = {Type.CHAR, Type.NVARCHAR, Type.VARCHAR, Type.NCHAR}
 (regenerated docs: same members, different set iteration order)


@@ -1893,7 +1893,7 @@ belong to some totally-ordered set.
 DATE_UNITS = {'year', 'month', 'quarter', 'day', 'year_month', 'week'}
 (regenerated docs: same members, different set iteration order)


@@ -577,7 +577,7 @@
 ALL_JSON_PATH_PARTS = {JSONPathSlice, JSONPathScript, JSONPathRoot, JSONPathRecursive, JSONPathKey, JSONPathWildcard, JSONPathFilter, JSONPathUnion, JSONPathSubscript, JSONPathSelector}
 (regenerated docs: same members, different set iteration order)

File diff suppressed because one or more lines are too long


@@ -586,7 +586,7 @@ queries if it would result in multiple table selects in a single query:
 UNMERGABLE_ARGS = {'sample', 'prewhere', 'offset', 'group', 'with', 'laterals', 'kind', 'distinct', 'having', 'sort', 'cluster', 'limit', 'format', 'locks', 'distribute', 'settings', 'match', 'connect', 'qualify', 'options', 'windows', 'into', 'pivots'}
 (regenerated docs: same members, different set iteration order)


@@ -3220,7 +3220,7 @@ prefix are statically known.
 DATETRUNC_COMPARISONS = {EQ, GT, LT, NEQ, In, GTE, LTE}
 (regenerated docs: same members, different set iteration order)
@@ -3300,7 +3300,7 @@ prefix are statically known.
 JOINS = {('RIGHT', ''), ('', 'INNER'), ('RIGHT', 'OUTER'), ('', '')}
 (regenerated docs: same members, different set iteration order)

3 file diffs suppressed because one or more lines are too long


@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     date_add_interval_sql,
     datestrtodate_sql,
     build_formatted_time,
+    build_timestamp_from_parts,
     filter_array_using_unnest,
     if_sql,
     inline_array_unless_query,
@@ -22,6 +23,7 @@ from sqlglot.dialects.dialect import (
     build_date_delta_with_interval,
     regexp_replace_sql,
     rename_func,
+    sha256_sql,
     timestrtotime_sql,
     ts_or_ds_add_cast,
     unit_to_var,
@@ -321,6 +323,7 @@ class BigQuery(Dialect):
                 unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
+            "DATETIME": build_timestamp_from_parts,
             "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
             "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
             "DIV": binary_from_function(exp.IntDiv),
@@ -637,9 +640,7 @@ class BigQuery(Dialect):
                 ]
             ),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
             exp.StabilityProperty: lambda self, e: (
                 "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
             ),
@@ -649,6 +650,7 @@ class BigQuery(Dialect):
             ),
             exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
             exp.TimeFromParts: rename_func("TIME"),
+            exp.TimestampFromParts: rename_func("DATETIME"),
             exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
             exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
             exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
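Taken together, these hunks wire BigQuery's DATETIME(...) to exp.TimestampFromParts in both directions. A minimal sketch of the intended round trip, assuming a sqlglot build that includes this change (the SQL string and printed output are illustrative, not taken from the diff):

```python
import sqlglot

# DATETIME(year, month, day, hour, min, sec) now parses into
# exp.TimestampFromParts and is generated back as DATETIME(...).
ast = sqlglot.parse_one("SELECT DATETIME(2024, 6, 6, 12, 30, 0)", read="bigquery")
assert ast.find(sqlglot.exp.TimestampFromParts) is not None
print(ast.sql(dialect="bigquery"))  # SELECT DATETIME(2024, 6, 6, 12, 30, 0)
```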


@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     build_json_extract_path,
     rename_func,
+    sha256_sql,
     var_map_sql,
     timestamptrunc_sql,
 )
@@ -758,9 +759,7 @@ class ClickHouse(Dialect):
             exp.MD5Digest: rename_func("MD5"),
             exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
             exp.UnixToTime: _unix_to_time_sql,
             exp.TimestampTrunc: timestamptrunc_sql(zone=True),
             exp.Variance: rename_func("varSamp"),


@@ -169,6 +169,7 @@ class _Dialect(type):
         if enum not in ("", "athena", "presto", "trino"):
             klass.generator_class.TRY_SUPPORTED = False
+            klass.generator_class.SUPPORTS_UESCAPE = False

         if enum not in ("", "databricks", "hive", "spark", "spark2"):
             modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
@@ -177,6 +178,14 @@ class _Dialect(type):
             klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

+        if enum not in ("", "doris", "mysql"):
+            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
+                TokenType.STRAIGHT_JOIN,
+            }
+            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
+                TokenType.STRAIGHT_JOIN,
+            }
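The hook above makes STRAIGHT_JOIN an ordinary identifier everywhere except MySQL and Doris. A rough sketch of what that implies, assuming STRAIGHT_JOIN is tokenized as a keyword in the base tokenizer (the snippet and outputs are illustrative assumptions):

```python
import sqlglot

# Outside MySQL/Doris, STRAIGHT_JOIN can name a column or alias a table...
print(sqlglot.parse_one("SELECT straight_join FROM t", read="duckdb").sql())

# ...while MySQL keeps treating it as a join modifier.
print(sqlglot.parse_one("SELECT * FROM a STRAIGHT_JOIN b", read="mysql").sql(dialect="mysql"))
```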
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -220,6 +229,9 @@ class Dialect(metaclass=_Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = True
     """Whether `SEMI` or `ANTI` joins are supported."""

+    SUPPORTS_COLUMN_JOIN_MARKS = False
+    """Whether the old-style outer join (+) syntax is supported."""
+
     NORMALIZE_FUNCTIONS: bool | str = "upper"
     """
     Determines how function names are going to be normalized.
@@ -1178,3 +1190,16 @@ def build_default_decimal_type(
         return exp.DataType.build(f"DECIMAL({params})")

     return _builder
+
+
+def build_timestamp_from_parts(args: t.List) -> exp.Func:
+    if len(args) == 2:
+        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
+        # so we parse this into Anonymous for now instead of introducing complexity
+        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
+
+    return exp.TimestampFromParts.from_arg_list(args)
+
+
+def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
+    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
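The new sha256_sql helper centralizes a transform that BigQuery, ClickHouse, Presto, and now Postgres previously spelled out as near-identical lambdas, and it also covers SHA384 via the length argument. A hedged usage sketch (SQL strings and outputs are illustrative assumptions):

```python
import sqlglot

# exp.SHA2 carries a `length` argument; sha256_sql renders it as SHA<length>,
# defaulting to SHA256 when no length is given.
print(sqlglot.transpile("SELECT SHA2(x, 384)", read="snowflake", write="postgres")[0])
# SELECT SHA384(x)
print(sqlglot.transpile("SELECT SHA2(x, 256)", read="snowflake", write="bigquery")[0])
# SELECT SHA256(x)
```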


@@ -207,7 +207,7 @@ class DuckDB(Dialect):
             "PIVOT_WIDER": TokenType.PIVOT,
             "POSITIONAL": TokenType.POSITIONAL,
             "SIGNED": TokenType.INT,
-            "STRING": TokenType.VARCHAR,
+            "STRING": TokenType.TEXT,
             "UBIGINT": TokenType.UBIGINT,
             "UINTEGER": TokenType.UINT,
             "USMALLINT": TokenType.USMALLINT,
@@ -216,6 +216,7 @@ class DuckDB(Dialect):
             "TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
             "TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
             "TIMESTAMP_US": TokenType.TIMESTAMP,
+            "VARCHAR": TokenType.TEXT,
         }

         SINGLE_TOKENS = {
@@ -312,9 +313,11 @@ class DuckDB(Dialect):
             ),
         }

-        TYPE_CONVERTER = {
+        TYPE_CONVERTERS = {
             # https://duckdb.org/docs/sql/data_types/numeric
             exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3),
+            # https://duckdb.org/docs/sql/data_types/text
+            exp.DataType.Type.TEXT: lambda dtype: exp.DataType.build("TEXT"),
         }
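With STRING and VARCHAR both tokenized as TEXT and a TEXT entry in TYPE_CONVERTERS, DuckDB's string aliases should now normalize to a single TEXT type. A small sketch under that assumption (output approximate):

```python
import sqlglot

# STRING/VARCHAR are aliases of TEXT in DuckDB, so they round-trip as TEXT.
print(sqlglot.transpile("SELECT CAST(x AS STRING)", read="duckdb", write="duckdb")[0])
# SELECT CAST(x AS TEXT)
```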
         def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
@@ -495,6 +498,7 @@ class DuckDB(Dialect):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.BPCHAR: "TEXT",
             exp.DataType.Type.CHAR: "TEXT",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.NCHAR: "TEXT",


@@ -202,6 +202,7 @@ class MySQL(Dialect):
             "CHARSET": TokenType.CHARACTER_SET,
             "FORCE": TokenType.FORCE,
             "IGNORE": TokenType.IGNORE,
+            "KEY": TokenType.KEY,
             "LOCK TABLES": TokenType.COMMAND,
             "LONGBLOB": TokenType.LONGBLOB,
             "LONGTEXT": TokenType.LONGTEXT,


@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     trim_sql,
 )
 from sqlglot.helper import seq_get
+from sqlglot.parser import OPTIONS_TYPE
 from sqlglot.tokens import TokenType

 if t.TYPE_CHECKING:
@@ -32,10 +33,171 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
     return exp.ToChar.from_arg_list(args)
+
+def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
+    """Remove join marks from an expression
+
+    SELECT * FROM a, b WHERE a.id = b.id(+)
+    becomes:
+    SELECT * FROM a LEFT JOIN b ON a.id = b.id
+
+    - for each scope
+        - for each column with a join mark
+            - find the predicate it belongs to
+            - remove the predicate from the where clause
+            - convert the predicate to a join with the (+) side as the left join table
+            - replace the existing join with the new join
+
+    Args:
+        ast: The AST to remove join marks from
+
+    Returns:
+        The AST with join marks removed
+    """
+    from sqlglot.optimizer.scope import traverse_scope
+
+    for scope in traverse_scope(ast):
+        _eliminate_join_marks_from_scope(scope)
+    return ast
+
+
+def _update_from(
+    select: exp.Select,
+    new_join_dict: t.Dict[str, exp.Join],
+    old_join_dict: t.Dict[str, exp.Join],
+) -> None:
+    """If the from clause needs to become a new join, find an appropriate table to use as the new from.
+    Updates select in place.
+
+    Args:
+        select: The select statement to update
+        new_join_dict: The dictionary of new joins
+        old_join_dict: The dictionary of old joins
+    """
+    old_from = select.args["from"]
+    if old_from.alias_or_name not in new_join_dict:
+        return
+    in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
+    if len(in_old_not_new) >= 1:
+        new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
+        new_from_this = old_join_dict[new_from_name].this
+        new_from = exp.From(this=new_from_this)
+        del old_join_dict[new_from_name]
+        select.set("from", new_from)
+    else:
+        raise ValueError("Cannot determine which table to use as the new from")
+
+
+def _has_join_mark(col: exp.Expression) -> bool:
+    """Check if the column has a join mark
+
+    Args:
+        col: The column to check
+    """
+    return col.args.get("join_mark", False)
+
+
+def _predicate_to_join(
+    eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
+) -> t.Optional[exp.Join]:
+    """Convert an equality predicate to a join if it contains a join mark
+
+    Args:
+        eq: The equality expression to convert to a join
+
+    Returns:
+        The join expression if the equality contains a join mark (otherwise None)
+    """
+    # if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
+    #     return None
+
+    left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
+    right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
+
+    left_has_join_mark = len(left_columns) > 0
+    right_has_join_mark = len(right_columns) > 0
+
+    if left_has_join_mark:
+        for col in left_columns:
+            col.set("join_mark", False)
+            join_on = col.table
+    elif right_has_join_mark:
+        for col in right_columns:
+            col.set("join_mark", False)
+            join_on = col.table
+    else:
+        return None
+
+    join_this = old_joins.get(join_on, old_from).this
+    return exp.Join(this=join_this, on=eq, kind="LEFT")
+
+
+if t.TYPE_CHECKING:
+    from sqlglot.optimizer.scope import Scope
+
+
+def _eliminate_join_marks_from_scope(scope: Scope) -> None:
+    """Remove join marks columns in scope's where clause.
+    Converts them to left joins and replaces any existing joins.
+    Updates scope in place.
+
+    Args:
+        scope: The scope to remove join marks from
+    """
+    select_scope = scope.expression
+    where = select_scope.args.get("where")
+    joins = select_scope.args.get("joins")
+    if not where:
+        return
+    if not joins:
+        return
+
+    # dictionaries used to keep track of joins to be replaced
+    old_joins = {join.alias_or_name: join for join in list(joins)}
+    new_joins: t.Dict[str, exp.Join] = {}
+
+    for node in scope.find_all(exp.Column):
+        if _has_join_mark(node):
+            predicate = node.find_ancestor(exp.Predicate)
+            if not isinstance(predicate, exp.Binary):
+                continue
+            predicate_parent = predicate.parent
+
+            join_on = predicate.pop()
+            new_join = _predicate_to_join(
+                join_on, old_joins=old_joins, old_from=select_scope.args["from"]
+            )
+            # upsert new_join into the new_joins dictionary
+            if new_join:
+                if new_join.alias_or_name in new_joins:
+                    new_joins[new_join.alias_or_name].set(
+                        "on",
+                        exp.and_(
+                            new_joins[new_join.alias_or_name].args["on"],
+                            new_join.args["on"],
+                        ),
+                    )
+                else:
+                    new_joins[new_join.alias_or_name] = new_join
+            # If the parent is a binary node with only one child, promote the child to the parent
+            if predicate_parent:
+                if isinstance(predicate_parent, exp.Binary):
+                    if predicate_parent.left is None:
+                        predicate_parent.replace(predicate_parent.right)
+                    elif predicate_parent.right is None:
+                        predicate_parent.replace(predicate_parent.left)
+
+    _update_from(select_scope, new_joins, old_joins)
+    replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
+    select_scope.set("joins", replacement_joins)
+    if not where.this:
+        where.pop()
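A short usage sketch of the helper added above, built directly from its docstring example (import path per this diff; the printed output is an expectation, not captured from a run):

```python
import sqlglot
from sqlglot.dialects.oracle import eliminate_join_marks

# Old-style Oracle (+) predicates are rewritten into explicit LEFT JOINs.
ast = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
print(eliminate_join_marks(ast).sql(dialect="oracle"))
# SELECT * FROM a LEFT JOIN b ON a.id = b.id
```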
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     LOCKING_READS_SUPPORTED = True
     TABLESAMPLE_SIZE_IS_PERCENT = True
+    SUPPORTS_COLUMN_JOIN_MARKS = True

     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
     NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -70,6 +232,12 @@ class Oracle(Dialect):
     class Tokenizer(tokens.Tokenizer):
         VAR_SINGLE_TOKENS = {"@", "$", "#"}

+        UNICODE_STRINGS = [
+            (prefix + q, q)
+            for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
+            for prefix in ("U", "u")
+        ]
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "(+)": TokenType.JOIN_MARKER,
@@ -132,6 +300,7 @@ class Oracle(Dialect):
         QUERY_MODIFIER_PARSERS = {
             **parser.Parser.QUERY_MODIFIER_PARSERS,
             TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
+            TokenType.WITH: lambda self: ("options", [self._parse_query_restrictions()]),
         }

         TYPE_LITERAL_PARSERS = {
@@ -144,6 +313,13 @@ class Oracle(Dialect):
         # Reference: https://stackoverflow.com/a/336455
         DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}

+        QUERY_RESTRICTIONS: OPTIONS_TYPE = {
+            "WITH": (
+                ("READ", "ONLY"),
+                ("CHECK", "OPTION"),
+            ),
+        }
+
         def _parse_xml_table(self) -> exp.XMLTable:
             this = self._parse_string()
@@ -173,12 +349,6 @@ class Oracle(Dialect):
                 **kwargs,
             )

-        def _parse_column(self) -> t.Optional[exp.Expression]:
-            column = super()._parse_column()
-            if column:
-                column.set("join_mark", self._match(TokenType.JOIN_MARKER))
-            return column
-
         def _parse_hint(self) -> t.Optional[exp.Hint]:
             if self._match(TokenType.HINT):
                 start = self._curr
@@ -193,11 +363,22 @@ class Oracle(Dialect):
             return None

+        def _parse_query_restrictions(self) -> t.Optional[exp.Expression]:
+            kind = self._parse_var_from_options(self.QUERY_RESTRICTIONS, raise_unmatched=False)
+
+            if not kind:
+                return None
+
+            return self.expression(
+                exp.QueryOption,
+                this=kind,
+                expression=self._match(TokenType.CONSTRAINT) and self._parse_field(),
+            )
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         JOIN_HINTS = False
         TABLE_HINTS = False
+        COLUMN_JOIN_MARKS_SUPPORTED = True
         DATA_TYPE_SPECIFIERS_ALLOWED = True
         ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
         LIMIT_FETCH = "FETCH"
@@ -282,3 +463,10 @@ class Oracle(Dialect):
             if len(expression.args.get("actions", [])) > 1:
                 return f"ADD ({actions})"
             return f"ADD {actions}"
+
+        def queryoption_sql(self, expression: exp.QueryOption) -> str:
+            option = self.sql(expression, "this")
+            value = self.sql(expression, "expression")
+            value = f" CONSTRAINT {value}" if value else ""
+
+            return f"{option}{value}"
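Between QUERY_RESTRICTIONS, _parse_query_restrictions, and queryoption_sql, Oracle's query-level restriction clauses should now round-trip. A hedged sketch (SQL and output are illustrative assumptions):

```python
import sqlglot

# WITH READ ONLY / WITH CHECK OPTION parse into exp.QueryOption and are
# re-emitted by queryoption_sql, with an optional CONSTRAINT name.
print(sqlglot.parse_one("SELECT * FROM t WITH READ ONLY", read="oracle").sql(dialect="oracle"))
# SELECT * FROM t WITH READ ONLY
```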


@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     Dialect,
     JSON_EXTRACT_TYPE,
     any_value_to_max_sql,
+    binary_from_function,
     bool_xor_sql,
     datestrtodate_sql,
     build_formatted_time,
@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
     build_json_extract_path,
     build_timestamp_trunc,
     rename_func,
+    sha256_sql,
     str_position_sql,
     struct_extract_sql,
     timestamptrunc_sql,
@@ -329,6 +331,7 @@ class Postgres(Dialect):
             "REGTYPE": TokenType.OBJECT_IDENTIFIER,
             "FLOAT": TokenType.DOUBLE,
         }
+        KEYWORDS.pop("DIV")

         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
@@ -347,6 +350,9 @@ class Postgres(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE_TRUNC": build_timestamp_trunc,
+            "DIV": lambda args: exp.cast(
+                binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
+            ),
             "GENERATE_SERIES": _build_generate_series,
             "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
             "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
@@ -357,6 +363,9 @@ class Postgres(Dialect):
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
             "TO_TIMESTAMP": _build_to_timestamp,
             "UNNEST": exp.Explode.from_arg_list,
+            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+            "SHA384": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(384)),
+            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
         }

         FUNCTION_PARSERS = {
@@ -494,6 +503,7 @@ class Postgres(Dialect):
             exp.DateSub: _date_add_sql("-"),
             exp.Explode: rename_func("UNNEST"),
             exp.GroupConcat: _string_agg_sql,
+            exp.IntDiv: rename_func("DIV"),
             exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
             exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
             exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -528,6 +538,7 @@ class Postgres(Dialect):
                     transforms.eliminate_qualify,
                 ]
             ),
+            exp.SHA2: sha256_sql,
             exp.StrPosition: str_position_sql,
             exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
@@ -621,3 +632,12 @@ class Postgres(Dialect):
                     return f"{self.expressions(expression, flat=True)}[{values}]"
                 return "ARRAY"
             return super().datatype_sql(expression)
+
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+            this = expression.this
+
+            # Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
+            if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"):
+                return self.sql(this)
+
+            return super().cast_sql(expression, safe_prefix=safe_prefix)
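The DIV pieces fit together as follows: parsing wraps DIV(a, b) in a decimal cast (Postgres's DIV returns numeric), other dialects then see an ordinary cast of an integer division, and the new cast_sql keeps Postgres-to-Postgres output clean. A hedged sketch (inputs and outputs are illustrative assumptions):

```python
import sqlglot

# Round-tripping through Postgres strips the synthetic decimal cast again.
print(sqlglot.transpile("SELECT DIV(8, 3)", read="postgres", write="postgres")[0])
# SELECT DIV(8, 3)

# Other targets keep the cast, so DIV's numeric return type is preserved
# (the exact integer-division and decimal spelling varies per dialect).
print(sqlglot.transpile("SELECT DIV(8, 3)", read="postgres", write="duckdb")[0])
```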


@@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
     regexp_extract_sql,
     rename_func,
     right_to_substring_sql,
+    sha256_sql,
     struct_extract_sql,
     str_position_sql,
     timestamptrunc_sql,
@@ -452,9 +453,7 @@ class Presto(Dialect):
             ),
             exp.MD5Digest: rename_func("MD5"),
             exp.SHA: rename_func("SHA1"),
-            exp.SHA2: lambda self, e: self.func(
-                "SHA256" if e.text("length") == "256" else "SHA512", e.this
-            ),
+            exp.SHA2: sha256_sql,
         }

         RESERVED_KEYWORDS = {


@@ -40,6 +40,7 @@ class Redshift(Postgres):
     INDEX_OFFSET = 0
     COPY_PARAMS_ARE_CSV = False
     HEX_LOWERCASE = True
+    SUPPORTS_COLUMN_JOIN_MARKS = True

     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
@@ -122,12 +123,13 @@ class Redshift(Postgres):
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,
+            "(+)": TokenType.JOIN_MARKER,
             "HLLSKETCH": TokenType.HLLSKETCH,
+            "MINUS": TokenType.EXCEPT,
             "SUPER": TokenType.SUPER,
             "TOP": TokenType.TOP,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
-            "MINUS": TokenType.EXCEPT,
         }

         KEYWORDS.pop("VALUES")
@@ -209,6 +211,7 @@ class Redshift(Postgres):
         # Redshift supports LAST_DAY(..)
         TRANSFORMS.pop(exp.LastDay)
+        TRANSFORMS.pop(exp.SHA2)

         RESERVED_KEYWORDS = {
             "aes128",


@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     binary_from_function,
     build_default_decimal_type,
+    build_timestamp_from_parts,
     date_delta_sql,
     date_trunc_to_time,
     datestrtodate_sql,
@@ -236,15 +237,6 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
     return trunc

-def _build_timestamp_from_parts(args: t.List) -> exp.Func:
-    if len(args) == 2:
-        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
-        # so we parse this into Anonymous for now instead of introducing complexity
-        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
-
-    return exp.TimestampFromParts.from_arg_list(args)
-
 def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
     """
     Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
@@ -391,8 +383,8 @@ class Snowflake(Dialect):
             "TIMEDIFF": _build_datediff,
             "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
             "TIMESTAMPDIFF": _build_datediff,
-            "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
-            "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
+            "TIMESTAMPFROMPARTS": build_timestamp_from_parts,
+            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
             "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
             "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
             "TO_NUMBER": lambda args: exp.ToNumber(
@@ -446,7 +438,7 @@ class Snowflake(Dialect):
             "LOCATION": lambda self: self._parse_location_property(),
         }

-        TYPE_CONVERTER = {
+        TYPE_CONVERTERS = {
             # https://docs.snowflake.com/en/sql-reference/data-types-numeric#number
             exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0),
         }
@@ -510,15 +502,18 @@ class Snowflake(Dialect):
             self._retreat(self._index - 1)

         if self._match_text_seq("MASKING", "POLICY"):
+            policy = self._parse_column()
             return self.expression(
                 exp.MaskingPolicyColumnConstraint,
-                this=self._parse_id_var(),
+                this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
                 expressions=self._match(TokenType.USING)
                 and self._parse_wrapped_csv(self._parse_id_var),
             )

         if self._match_text_seq("PROJECTION", "POLICY"):
+            policy = self._parse_column()
             return self.expression(
-                exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
+                exp.ProjectionPolicyColumnConstraint,
+                this=policy.to_dot() if isinstance(policy, exp.Column) else policy,
             )

         if self._match(TokenType.TAG):
             return self.expression(


@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
     )

+
+def _build_dateadd(args: t.List) -> exp.Expression:
+    expression = seq_get(args, 1)
+
+    if len(args) == 2:
+        # DATE_ADD(startDate, numDays INTEGER)
+        # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
+        return exp.TsOrDsAdd(
+            this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY")
+        )
+
+    # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
+    # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
+    return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0))
+
 def _normalize_partition(e: exp.Expression) -> exp.Expression:
     """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
     if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
     return e

+
+def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
+    if not expression.unit or (
+        isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
+    ):
+        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
+        return self.func("DATE_ADD", expression.this, expression.expression)
+
+    this = self.func(
+        "DATE_ADD",
+        unit_to_var(expression),
+        expression.expression,
+        expression.this,
+    )
+
+    if isinstance(expression, exp.TsOrDsAdd):
+        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
+        # in other dialects
+        return_type = expression.return_type
+        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
+            this = f"CAST({this} AS {return_type})"
+
+    return this
class Spark(Spark2): class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer): class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [ RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Spark(Spark2):
FUNCTIONS = { FUNCTIONS = {
**Spark2.Parser.FUNCTIONS, **Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue), "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATE_ADD": _build_dateadd,
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff, "DATEDIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Spark(Spark2):
exp.PartitionedByProperty: lambda self, exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"), exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func( exp.TsOrDsAdd: _dateadd_sql,
"DATEADD", unit_to_var(e), e.expression, e.this exp.TimestampAdd: _dateadd_sql,
),
exp.TryCast: lambda self, e: ( exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
), ),

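A short sketch of the two `DATE_ADD` arities handled above, per the new Spark tests further down:

```python
import sqlglot

# Two arguments: add N days (the legacy Hive/Spark 2 behavior is preserved).
print(sqlglot.transpile("SELECT DATE_ADD(my_date_column, 1)", read="spark", write="spark")[0])
# SELECT DATE_ADD(my_date_column, 1)

# Three arguments: unit-first form; TIMESTAMPADD is normalized to DATE_ADD.
print(sqlglot.transpile("SELECT TIMESTAMPADD(MONTH, 20, col)", read="spark", write="spark")[0])
# SELECT DATE_ADD(MONTH, 20, col)
```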

@@ -75,6 +75,26 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
return expression return expression
def _generated_to_auto_increment(expression: exp.Expression) -> exp.Expression:
if not isinstance(expression, exp.ColumnDef):
return expression
generated = expression.find(exp.GeneratedAsIdentityColumnConstraint)
if generated:
t.cast(exp.ColumnConstraint, generated.parent).pop()
not_null = expression.find(exp.NotNullColumnConstraint)
if not_null:
t.cast(exp.ColumnConstraint, not_null.parent).pop()
expression.append(
"constraints", exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint())
)
return expression
class SQLite(Dialect): class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@@ -141,6 +161,7 @@ class SQLite(Dialect):
exp.CurrentDate: lambda *_: "CURRENT_DATE", exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME", exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ColumnDef: transforms.preprocess([_generated_to_auto_increment]),
exp.DateAdd: _date_add_sql, exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.If: rename_func("IIF"), exp.If: rename_func("IIF"),

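A sketch of the new SQLite transform: `GENERATED ... AS IDENTITY` (and the accompanying `NOT NULL`) is rewritten to `AUTOINCREMENT`, matching the updated SQLite test:

```python
import sqlglot

sql = "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)"
print(sqlglot.transpile(sql, read="postgres", write="sqlite")[0])
# CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)
```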

@@ -1118,3 +1118,7 @@ class TSQL(Dialect):
kind = f"TABLE {kind}" kind = f"TABLE {kind}"
return f"{variable} AS {kind}{default}" return f"{variable} AS {kind}{default}"
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" OPTION{self.wrap(options)}" if options else ""

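A hedged sketch of the T-SQL hook above; `MAXDOP 1` is an illustrative hint, and the example assumes the hint parses into the statement's `options` arg on this version:

```python
import sqlglot

sql = "SELECT * FROM t OPTION(MAXDOP 1)"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
# SELECT * FROM t OPTION(MAXDOP 1)
```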

@@ -3119,22 +3119,6 @@ class Intersect(Union):
pass pass
class Unnest(UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Update(Expression): class Update(Expression):
arg_types = { arg_types = {
"with": False, "with": False,
@@ -5240,6 +5224,22 @@ class PosexplodeOuter(Posexplode, ExplodeOuter):
pass pass
class Unnest(Func, UDTF):
arg_types = {
"expressions": True,
"alias": False,
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
columns = super().selects
offset = self.args.get("offset")
if offset:
columns = columns + [to_identifier("offset") if offset is True else offset]
return columns
class Floor(Func): class Floor(Func):
arg_types = {"this": True, "decimals": False} arg_types = {"this": True, "decimals": False}
@@ -5765,7 +5765,7 @@ class StrPosition(Func):
class StrToDate(Func): class StrToDate(Func):
arg_types = {"this": True, "format": True} arg_types = {"this": True, "format": False}
class StrToTime(Func): class StrToTime(Func):

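Two effects of the expression changes above, sketched: `Unnest` now subclasses `Func`, so plain function syntax parses to it, and `StrToDate.format` is optional (e.g. Oracle's one-argument `TO_DATE`), both per the new tests:

```python
from sqlglot import exp, parse_one

assert parse_one("SELECT UNNEST(x)").find(exp.Unnest) is not None

print(parse_one("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')", read="oracle").sql("oracle"))
# SELECT TO_DATE('January 15, 1989, 11:00 A.M.')
```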

@@ -225,9 +225,6 @@ class Generator(metaclass=_Generator):
# Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True RETURNING_END = True
# Whether to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
# Whether to generate an unquoted value for EXTRACT's date part argument # Whether to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True EXTRACT_ALLOWS_QUOTES = True
@@ -359,6 +356,9 @@ class Generator(metaclass=_Generator):
# Whether the conditional TRY(expression) function is supported # Whether the conditional TRY(expression) function is supported
TRY_SUPPORTED = True TRY_SUPPORTED = True
# Whether the UESCAPE syntax in unicode strings is supported
SUPPORTS_UESCAPE = True
# The keyword to use when generating a star projection with excluded columns # The keyword to use when generating a star projection with excluded columns
STAR_EXCEPT = "EXCEPT" STAR_EXCEPT = "EXCEPT"
@@ -827,7 +827,7 @@ class Generator(metaclass=_Generator):
def column_sql(self, expression: exp.Column) -> str: def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else "" join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED: if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS:
join_mark = "" join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.") self.unsupported("Outer join syntax using the (+) operator is not supported.")
@@ -1146,16 +1146,23 @@ class Generator(metaclass=_Generator):
escape = expression.args.get("escape") escape = expression.args.get("escape")
if self.dialect.UNICODE_START: if self.dialect.UNICODE_START:
escape = f" UESCAPE {self.sql(escape)}" if escape else "" escape_substitute = r"\\\1"
return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}" left_quote, right_quote = self.dialect.UNICODE_START, self.dialect.UNICODE_END
else:
escape_substitute = r"\\u\1"
left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END
if escape: if escape:
pattern = re.compile(rf"{escape.name}(\d+)") escape_pattern = re.compile(rf"{escape.name}(\d+)")
escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else ""
else: else:
pattern = ESCAPED_UNICODE_RE escape_pattern = ESCAPED_UNICODE_RE
escape_sql = ""
this = pattern.sub(r"\\u\1", this) if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE):
return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}" this = escape_pattern.sub(escape_substitute, this)
return f"{left_quote}{this}{right_quote}{escape_sql}"
def rawstring_sql(self, expression: exp.RawString) -> str: def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False) string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False)
@@ -1973,7 +1980,9 @@ class Generator(metaclass=_Generator):
return f", {this_sql}" return f", {this_sql}"
if op_sql != "STRAIGHT_JOIN":
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}" return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
@@ -2235,10 +2244,6 @@ class Generator(metaclass=_Generator):
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression)) limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
options = self.expressions(expression, key="options")
if options:
options = f" OPTION{self.wrap(options)}"
return csv( return csv(
*sqls, *sqls,
*[self.sql(join) for join in expression.args.get("joins") or []], *[self.sql(join) for join in expression.args.get("joins") or []],
@@ -2253,10 +2258,14 @@ class Generator(metaclass=_Generator):
self.sql(expression, "order"), self.sql(expression, "order"),
*self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit), *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit),
*self.after_limit_modifiers(expression), *self.after_limit_modifiers(expression),
options, self.options_modifier(expression),
sep="", sep="",
) )
def options_modifier(self, expression: exp.Expression) -> str:
options = self.expressions(expression, key="options")
return f" {options}" if options else ""
def queryoption_sql(self, expression: exp.QueryOption) -> str: def queryoption_sql(self, expression: exp.QueryOption) -> str:
return "" return ""

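A sketch of the unicode-string fallback introduced above: a dialect can support `U'...'` literals without the `UESCAPE` clause (Oracle here), in which case a custom escape character is rewritten to the default backslash form, as the Presto tests below show:

```python
import sqlglot

sql = "U&'Hello winter #2603 !' UESCAPE '#'"
print(sqlglot.transpile(sql, read="presto", write="oracle")[0])
# U'Hello winter \2603 !'
```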

@@ -1034,7 +1034,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return ( return (
DATETRUNC_BINARY_COMPARISONS[comparison]( DATETRUNC_BINARY_COMPARISONS[comparison](
trunc_arg, date, unit, dialect, extract_type(trunc_arg, r) trunc_arg, date, unit, dialect, extract_type(r)
) )
or expression or expression
) )
@@ -1060,7 +1060,7 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr
return expression return expression
ranges = merge_ranges(ranges) ranges = merge_ranges(ranges)
target_type = extract_type(l, *rs) target_type = extract_type(*rs)
return exp.or_( return exp.or_(
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False

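With the `extract_type` change above, the target type for a simplified `DATE_TRUNC` comparison now comes from the literal side only; a sketch based on the updated simplify fixtures:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

expr = parse_one("DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP)")
print(simplify(expr).sql())
# CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP)
```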

@@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
} }
JOIN_KINDS = { JOIN_KINDS = {
TokenType.ANTI,
TokenType.CROSS,
TokenType.INNER, TokenType.INNER,
TokenType.OUTER, TokenType.OUTER,
TokenType.CROSS,
TokenType.SEMI, TokenType.SEMI,
TokenType.ANTI, TokenType.STRAIGHT_JOIN,
} }
JOIN_HINTS: t.Set[str] = set() JOIN_HINTS: t.Set[str] = set()
@@ -1065,7 +1066,7 @@ class Parser(metaclass=_Parser):
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this), exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
} }
TYPE_CONVERTER: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {} TYPE_CONVERTERS: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN} DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
@@ -1138,7 +1139,14 @@ class Parser(metaclass=_Parser):
FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} ADD_CONSTRAINT_TOKENS = {
TokenType.CONSTRAINT,
TokenType.FOREIGN_KEY,
TokenType.INDEX,
TokenType.KEY,
TokenType.PRIMARY_KEY,
TokenType.UNIQUE,
}
DISTINCT_TOKENS = {TokenType.DISTINCT} DISTINCT_TOKENS = {TokenType.DISTINCT}
@@ -3099,7 +3107,7 @@ class Parser(metaclass=_Parser):
index = self._index index = self._index
method, side, kind = self._parse_join_parts() method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN) join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)
if not skip_join_token and not join: if not skip_join_token and not join:
self._retreat(index) self._retreat(index)
@@ -3242,7 +3250,7 @@ class Parser(metaclass=_Parser):
while self._match_set(self.TABLE_INDEX_HINT_TOKENS): while self._match_set(self.TABLE_INDEX_HINT_TOKENS):
hint = exp.IndexTableHint(this=self._prev.text.upper()) hint = exp.IndexTableHint(this=self._prev.text.upper())
self._match_texts(("INDEX", "KEY")) self._match_set((TokenType.INDEX, TokenType.KEY))
if self._match(TokenType.FOR): if self._match(TokenType.FOR):
hint.set("target", self._advance_any() and self._prev.text.upper()) hint.set("target", self._advance_any() and self._prev.text.upper())
@@ -4464,8 +4472,8 @@ class Parser(metaclass=_Parser):
) )
self._match(TokenType.R_BRACKET) self._match(TokenType.R_BRACKET)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type): if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this) converter = self.TYPE_CONVERTERS.get(this.this)
if converter: if converter:
this = converter(t.cast(exp.DataType, this)) this = converter(t.cast(exp.DataType, this))
@@ -4496,7 +4504,12 @@ class Parser(metaclass=_Parser):
def _parse_column(self) -> t.Optional[exp.Expression]: def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_column_reference() this = self._parse_column_reference()
return self._parse_column_ops(this) if this else self._parse_bracket(this) column = self._parse_column_ops(this) if this else self._parse_bracket(this)
if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column:
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_column_reference(self) -> t.Optional[exp.Expression]: def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field() this = self._parse_field()
@@ -4522,7 +4535,11 @@ class Parser(metaclass=_Parser):
while self._match(TokenType.COLON): while self._match(TokenType.COLON):
start_index = self._index start_index = self._index
path = self._parse_column_ops(self._parse_field(any_token=True))
# Snowflake allows reserved keywords as json keys but advance_any() excludes TokenType.SELECT from any_tokens=True
path = self._parse_column_ops(
self._parse_field(any_token=True, tokens=(TokenType.SELECT,))
)
# The cast :: operator has a lower precedence than the extraction operator :, so # The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the JSON path # we rearrange the AST appropriately to avoid casting the JSON path

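Sketches of three parser changes above: `STRAIGHT_JOIN` as a join kind, the widened `ALTER TABLE ... ADD` constraint tokens, and reserved words as Snowflake JSON path keys, all mirrored by the dialect tests further down:

```python
import sqlglot

print(sqlglot.transpile("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y", read="mysql", write="mysql")[0])
# SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y

print(sqlglot.transpile("ALTER TABLE t ADD KEY `i` (`c`)", read="mysql", write="mysql")[0])
# ALTER TABLE t ADD INDEX `i` (`c`)

print(sqlglot.transpile("SELECT a:select", read="snowflake", write="snowflake")[0])
# SELECT GET_PATH(a, 'select')
```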

@@ -287,6 +287,7 @@ class TokenType(AutoName):
JOIN = auto() JOIN = auto()
JOIN_MARKER = auto() JOIN_MARKER = auto()
KEEP = auto() KEEP = auto()
KEY = auto()
KILL = auto() KILL = auto()
LANGUAGE = auto() LANGUAGE = auto()
LATERAL = auto() LATERAL = auto()
@@ -360,6 +361,7 @@ class TokenType(AutoName):
SORT_BY = auto() SORT_BY = auto()
START_WITH = auto() START_WITH = auto()
STORAGE_INTEGRATION = auto() STORAGE_INTEGRATION = auto()
STRAIGHT_JOIN = auto()
STRUCT = auto() STRUCT = auto()
TABLE_SAMPLE = auto() TABLE_SAMPLE = auto()
TAG = auto() TAG = auto()
@@ -764,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SOME": TokenType.SOME, "SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY, "SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH, "START WITH": TokenType.START_WITH,
"STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
"TABLE": TokenType.TABLE, "TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE, "TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY, "TEMP": TokenType.TEMPORARY,
@@ -1270,18 +1273,6 @@ class Tokenizer(metaclass=_Tokenizer):
elif token_type == TokenType.BIT_STRING: elif token_type == TokenType.BIT_STRING:
base = 2 base = 2
elif token_type == TokenType.HEREDOC_STRING: elif token_type == TokenType.HEREDOC_STRING:
if (
self.HEREDOC_TAG_IS_IDENTIFIER
and not self._peek.isidentifier()
and not self._peek == end
):
if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR:
self._add(self.HEREDOC_STRING_ALTERNATIVE)
else:
self._scan_var()
return True
self._advance() self._advance()
if self._char == end: if self._char == end:
@@ -1293,7 +1284,10 @@ class Tokenizer(metaclass=_Tokenizer):
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER, raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
) )
if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER: if tag and self.HEREDOC_TAG_IS_IDENTIFIER and (self._end or not tag.isidentifier()):
if not self._end:
self._advance(-1)
self._advance(-len(tag)) self._advance(-len(tag))
self._add(self.HEREDOC_STRING_ALTERNATIVE) self._add(self.HEREDOC_STRING_ALTERNATIVE)
return True return True

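The heredoc change defers tag validation until after the tag is scanned: if the scanned tag is not a valid identifier, the tokenizer rewinds and emits the alternative token instead. A sketch with DuckDB dollar parameters (from the new test) and an ordinary dollar-quoted string:

```python
import sqlglot

print(sqlglot.transpile("SELECT * FROM foo WHERE bar > $baz AND bla = $bob", read="duckdb", write="duckdb")[0])
# SELECT * FROM foo WHERE bar > $baz AND bla = $bob

print(sqlglot.transpile("SELECT $tag$hello$tag$", read="postgres", write="postgres")[0])
# SELECT 'hello'
```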

@@ -505,7 +505,10 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression:
def _ensure_bool(node: exp.Expression) -> None: def _ensure_bool(node: exp.Expression) -> None:
if ( if (
node.is_number node.is_number
or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) or (
not isinstance(node, exp.SubqueryPredicate)
and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
)
or (isinstance(node, exp.Column) and not node.type) or (isinstance(node, exp.Column) and not node.type)
): ):
node.replace(node.neq(0)) node.replace(node.neq(0))

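With the guard above, subquery predicates such as `EXISTS` (already boolean) are no longer rewritten to a `<> 0` comparison after type annotation; a sketch mirroring the new T-SQL test:

```python
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(parse_one("SELECT 1 WHERE EXISTS(SELECT 1)", read="tsql"))
print(expr.sql("tsql"))
# SELECT 1 WHERE EXISTS(SELECT 1)
```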
sqlglotrs/Cargo.lock (generated)

@@ -188,7 +188,7 @@ checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]] [[package]]
name = "sqlglotrs" name = "sqlglotrs"
version = "0.2.5" version = "0.2.6"
dependencies = [ dependencies = [
"pyo3", "pyo3",
] ]


@@ -1,6 +1,6 @@
[package] [package]
name = "sqlglotrs" name = "sqlglotrs"
version = "0.2.5" version = "0.2.6"
edition = "2021" edition = "2021"
[lib] [lib]


@@ -405,19 +405,6 @@ impl<'a> TokenizerState<'a> {
} else if *token_type == self.token_types.bit_string { } else if *token_type == self.token_types.bit_string {
(Some(2), *token_type, end.clone()) (Some(2), *token_type, end.clone())
} else if *token_type == self.token_types.heredoc_string { } else if *token_type == self.token_types.heredoc_string {
if self.settings.heredoc_tag_is_identifier
&& !self.is_identifier(self.peek_char)
&& self.peek_char.to_string() != *end
{
if self.token_types.heredoc_string_alternative != self.token_types.var {
self.add(self.token_types.heredoc_string_alternative, None)?
} else {
self.scan_var()?
};
return Ok(true)
};
self.advance(1)?; self.advance(1)?;
let tag = if self.current_char.to_string() == *end { let tag = if self.current_char.to_string() == *end {
@@ -426,7 +413,14 @@ impl<'a> TokenizerState<'a> {
self.extract_string(end, false, false, !self.settings.heredoc_tag_is_identifier)? self.extract_string(end, false, false, !self.settings.heredoc_tag_is_identifier)?
}; };
if self.is_end && !tag.is_empty() && self.settings.heredoc_tag_is_identifier { if !tag.is_empty()
&& self.settings.heredoc_tag_is_identifier
&& (self.is_end || !self.is_identifier(&tag))
{
if !self.is_end {
self.advance(-1)?;
}
self.advance(-(tag.len() as isize))?; self.advance(-(tag.len() as isize))?;
self.add(self.token_types.heredoc_string_alternative, None)?; self.add(self.token_types.heredoc_string_alternative, None)?;
return Ok(true) return Ok(true)
@@ -494,7 +488,7 @@ impl<'a> TokenizerState<'a> {
} else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 { } else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 {
scientific += 1; scientific += 1;
self.advance(1)?; self.advance(1)?;
} else if self.is_identifier(self.peek_char) { } else if self.is_alphabetic_or_underscore(self.peek_char) {
let number_text = self.text(); let number_text = self.text();
let mut literal = String::from(""); let mut literal = String::from("");
@@ -676,10 +670,18 @@ impl<'a> TokenizerState<'a> {
Ok(text) Ok(text)
} }
fn is_identifier(&mut self, name: char) -> bool { fn is_alphabetic_or_underscore(&mut self, name: char) -> bool {
name.is_alphabetic() || name == '_' name.is_alphabetic() || name == '_'
} }
fn is_identifier(&mut self, s: &str) -> bool {
s.chars().enumerate().all(
|(i, c)|
if i == 0 { self.is_alphabetic_or_underscore(c) }
else { self.is_alphabetic_or_underscore(c) || c.is_digit(10) }
)
}
fn extract_value(&mut self) -> Result<String, TokenizerError> { fn extract_value(&mut self) -> Result<String, TokenizerError> {
loop { loop {
if !self.peek_char.is_whitespace() if !self.peek_char.is_whitespace()


@@ -20,6 +20,14 @@ class TestBigQuery(Validator):
maxDiff = None maxDiff = None
def test_bigquery(self): def test_bigquery(self):
self.validate_all(
"EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
write={
"bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
"duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
"snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
},
)
self.validate_identity( self.validate_identity(
"""CREATE TEMPORARY FUNCTION FOO() """CREATE TEMPORARY FUNCTION FOO()
RETURNS STRING RETURNS STRING
@@ -619,9 +627,9 @@ LANGUAGE js AS
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={ write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
"databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
"spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
}, },
) )
self.validate_all( self.validate_all(
@@ -761,12 +769,15 @@ LANGUAGE js AS
"clickhouse": "SHA256(x)", "clickhouse": "SHA256(x)",
"presto": "SHA256(x)", "presto": "SHA256(x)",
"trino": "SHA256(x)", "trino": "SHA256(x)",
"postgres": "SHA256(x)",
}, },
write={ write={
"bigquery": "SHA256(x)", "bigquery": "SHA256(x)",
"spark2": "SHA2(x, 256)", "spark2": "SHA2(x, 256)",
"clickhouse": "SHA256(x)", "clickhouse": "SHA256(x)",
"postgres": "SHA256(x)",
"presto": "SHA256(x)", "presto": "SHA256(x)",
"redshift": "SHA2(x, 256)",
"trino": "SHA256(x)", "trino": "SHA256(x)",
}, },
) )


@@ -18,6 +18,13 @@ class TestDuckDB(Validator):
"WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1", "WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1",
) )
self.validate_all(
"SELECT straight_join",
write={
"duckdb": "SELECT straight_join",
"mysql": "SELECT `straight_join`",
},
)
self.validate_all( self.validate_all(
"SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)", "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
read={ read={
@@ -278,6 +285,7 @@ class TestDuckDB(Validator):
self.validate_identity("FROM tbl", "SELECT * FROM tbl") self.validate_identity("FROM tbl", "SELECT * FROM tbl")
self.validate_identity("x -> '$.family'") self.validate_identity("x -> '$.family'")
self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))") self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob")
self.validate_identity( self.validate_identity(
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE" "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
) )
@@ -1000,6 +1008,7 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)") self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)")
self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)") self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)")
self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)") self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)")
@@ -1027,6 +1036,13 @@ class TestDuckDB(Validator):
"CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
) )
self.validate_all(
"CAST(x AS VARCHAR(5))",
write={
"duckdb": "CAST(x AS TEXT)",
"postgres": "CAST(x AS TEXT)",
},
)
self.validate_all( self.validate_all(
"CAST(x AS DECIMAL(38, 0))", "CAST(x AS DECIMAL(38, 0))",
read={ read={


@@ -21,6 +21,9 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE") self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)")
self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)")
self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT")
self.validate_identity( self.validate_identity(
"CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))" "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
) )
@@ -60,6 +63,10 @@ class TestMySQL(Validator):
self.validate_identity( self.validate_identity(
"CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q", "CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
) )
self.validate_identity(
"ALTER TABLE t ADD KEY `i` (`c`)",
"ALTER TABLE t ADD INDEX `i` (`c`)",
)
self.validate_identity( self.validate_identity(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
"CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
@@ -76,9 +83,6 @@ class TestMySQL(Validator):
"ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT",
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
) )
self.validate_identity(
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity( self.validate_identity(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
@@ -113,6 +117,7 @@ class TestMySQL(Validator):
) )
def test_identity(self): def test_identity(self):
self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2") self.validate_identity("SELECT @var1 := 1, @var2")


@@ -1,5 +1,5 @@
from sqlglot import exp from sqlglot import exp, UnsupportedError
from sqlglot.errors import UnsupportedError from sqlglot.dialects.oracle import eliminate_join_marks
from tests.dialects.test_dialect import Validator from tests.dialects.test_dialect import Validator
@@ -43,6 +43,7 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)") self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)")
self.validate_identity("SELECT * FROM V$SESSION") self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')")
self.validate_identity( self.validate_identity(
"SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name" "SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name"
) )
@@ -249,7 +250,8 @@ class TestOracle(Validator):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_all( self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError} "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
write={"": UnsupportedError},
) )
self.validate_all( self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
@@ -413,3 +415,65 @@ WHERE
for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"): for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True) self.validate_identity(query, pretty, pretty=True)
def test_eliminate_join_marks(self):
test_sql = [
(
"SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
"SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
),
(
"SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
"SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
),
(
"SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
"SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
),
(
"SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
"SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
),
(
"SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
"SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
),
(
"SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
"SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
),
(
"SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
"SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
),
(
"SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
"SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
),
# 2 join marks on one side of predicate
(
"SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
"SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
),
# join mark and expression
(
"SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
"SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
),
]
for original, expected in test_sql:
with self.subTest(original):
self.assertEqual(
eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
expected,
)
def test_query_restrictions(self):
for restriction in ("READ ONLY", "CHECK OPTION"):
for constraint_name in (" CONSTRAINT name", ""):
with self.subTest(f"Restriction: {restriction}"):
self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}")
self.validate_identity(
f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
)

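As the new tests show, `eliminate_join_marks` is importable from the Oracle dialect module; a condensed usage sketch:

```python
from sqlglot import parse_one
from sqlglot.dialects.oracle import eliminate_join_marks

sql = "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) AND T2.y (+) > 5"
print(eliminate_join_marks(parse_one(sql, read="oracle")).sql(dialect="oracle"))
# SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5
```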

@@ -8,6 +8,7 @@ class TestPostgres(Validator):
dialect = "postgres" dialect = "postgres"
def test_postgres(self): def test_postgres(self):
self.validate_identity("SHA384(x)")
self.validate_identity( self.validate_identity(
'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)" 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
) )
@@ -724,6 +725,28 @@ class TestPostgres(Validator):
self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)") self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)") self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")
self.validate_all(
"1 / DIV(4, 2)",
read={
"postgres": "1 / DIV(4, 2)",
},
write={
"sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
"duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
"bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
},
)
self.validate_all(
"CAST(DIV(4, 2) AS DECIMAL(5, 3))",
read={
"duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
},
write={
"duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
"postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
},
)
def test_ddl(self): def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier # Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is( self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(


@@ -564,6 +564,7 @@ class TestPresto(Validator):
self.validate_all( self.validate_all(
f"{prefix}'Hello winter \\2603 !'", f"{prefix}'Hello winter \\2603 !'",
write={ write={
"oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter \\2603 !'", "presto": "U&'Hello winter \\2603 !'",
"snowflake": "'Hello winter \\u2603 !'", "snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'", "spark": "'Hello winter \\u2603 !'",
@@ -572,6 +573,7 @@ class TestPresto(Validator):
self.validate_all( self.validate_all(
f"{prefix}'Hello winter #2603 !' UESCAPE '#'", f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
write={ write={
"oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter #2603 !' UESCAPE '#'", "presto": "U&'Hello winter #2603 !' UESCAPE '#'",
"snowflake": "'Hello winter \\u2603 !'", "snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'", "spark": "'Hello winter \\u2603 !'",


@@ -281,6 +281,9 @@ class TestRedshift(Validator):
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')", "redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
"snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))", "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))", "tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
"spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
"spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
"databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
}, },
) )
self.validate_all( self.validate_all(
@@ -585,3 +588,9 @@ FROM (
self.assertEqual( self.assertEqual(
ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l" ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
) )
def test_join_markers(self):
self.validate_identity(
"select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)",
"SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)",
)


@@ -125,6 +125,10 @@ WHERE
"SELECT a:from::STRING, a:from || ' test' ", "SELECT a:from::STRING, a:from || ' test' ",
"SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'", "SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
) )
self.validate_identity(
"SELECT a:select",
"SELECT GET_PATH(a, 'select')",
)
self.validate_identity("x:from", "GET_PATH(x, 'from')") self.validate_identity("x:from", "GET_PATH(x, 'from')")
self.validate_identity( self.validate_identity(
"value:values::string::int", "value:values::string::int",
@@ -1196,16 +1200,16 @@ WHERE
for constraint_prefix in ("WITH ", ""): for constraint_prefix in ("WITH ", ""):
with self.subTest(f"Constraint prefix: {constraint_prefix}"): with self.subTest(f"Constraint prefix: {constraint_prefix}"):
self.validate_identity( self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p)", f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)",
"CREATE TABLE t (id INT MASKING POLICY p)", "CREATE TABLE t (id INT MASKING POLICY p.q.r)",
) )
self.validate_identity( self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))", f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))",
"CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))", "CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))",
) )
self.validate_identity( self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p)", f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)",
"CREATE TABLE t (id INT PROJECTION POLICY p)", "CREATE TABLE t (id INT PROJECTION POLICY p.q.r)",
) )
self.validate_identity( self.validate_identity(
f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))", f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))",


@@ -563,6 +563,7 @@ TBLPROPERTIES (
"SELECT DATE_ADD(my_date_column, 1)", "SELECT DATE_ADD(my_date_column, 1)",
write={ write={
"spark": "SELECT DATE_ADD(my_date_column, 1)", "spark": "SELECT DATE_ADD(my_date_column, 1)",
"spark2": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)", "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
}, },
) )
@@ -675,6 +676,16 @@ TBLPROPERTIES (
"spark": "SELECT ARRAY_SORT(x)", "spark": "SELECT ARRAY_SORT(x)",
}, },
) )
self.validate_all(
"SELECT DATE_ADD(MONTH, 20, col)",
read={
"spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
},
write={
"spark": "SELECT DATE_ADD(MONTH, 20, col)",
"databricks": "SELECT DATE_ADD(MONTH, 20, col)",
},
)
def test_bool_or(self): def test_bool_or(self):
self.validate_all( self.validate_all(


@@ -202,6 +202,7 @@ class TestSQLite(Validator):
"CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
read={ read={
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
"postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
}, },
write={ write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",


@@ -1,12 +1,18 @@
from sqlglot import exp, parse, parse_one from sqlglot import exp, parse
from tests.dialects.test_dialect import Validator from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError from sqlglot.errors import ParseError
from sqlglot.optimizer.annotate_types import annotate_types
class TestTSQL(Validator): class TestTSQL(Validator):
dialect = "tsql" dialect = "tsql"
def test_tsql(self): def test_tsql(self):
self.assertEqual(
annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
"SELECT 1 WHERE EXISTS(SELECT 1)",
)
self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)") self.validate_identity("ROUND(x, 1, 0)")
@@ -217,9 +223,9 @@ class TestTSQL(Validator):
"CREATE TABLE [db].[tbl] ([a] INTEGER)", "CREATE TABLE [db].[tbl] ([a] INTEGER)",
) )
projection = parse_one("SELECT a = 1", read="tsql").selects[0] self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
projection.assert_is(exp.Alias) exp.Alias
projection.args["alias"].assert_is(exp.Identifier) ).args["alias"].assert_is(exp.Identifier)
self.validate_all( self.validate_all(
"IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName", "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@@ -756,12 +762,9 @@ class TestTSQL(Validator):
for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"): for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x") self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")
expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql") self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
self.assertIsInstance(expression, exp.AlterTable) exp.AlterTable
self.assertIsInstance(expression.args["actions"][0], exp.Drop) ).args["actions"][0].assert_is(exp.Drop)
self.assertEqual(
expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
)
for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"): for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity( self.validate_identity(
@@ -795,10 +798,10 @@ class TestTSQL(Validator):
) )
self.validate_all( self.validate_all(
"CREATE TABLE [#temptest] (name VARCHAR)", "CREATE TABLE [#temptest] (name INTEGER)",
read={ read={
"duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)", "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
"tsql": "CREATE TABLE [#temptest] (name VARCHAR)", "tsql": "CREATE TABLE [#temptest] (name INTEGER)",
}, },
) )
self.validate_all( self.validate_all(
@@ -1632,27 +1635,23 @@ WHERE
) )
def test_identifier_prefixes(self): def test_identifier_prefixes(self):
expr = parse_one("#x", read="tsql") self.assertTrue(
self.assertIsInstance(expr, exp.Column) self.validate_identity("#x")
self.assertIsInstance(expr.this, exp.Identifier) .assert_is(exp.Column)
self.assertTrue(expr.this.args.get("temporary")) .this.assert_is(exp.Identifier)
self.assertEqual(expr.sql("tsql"), "#x") .args.get("temporary")
)
self.assertTrue(
self.validate_identity("##x")
.assert_is(exp.Column)
.this.assert_is(exp.Identifier)
.args.get("global")
)
expr = parse_one("##x", read="tsql") self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)
self.assertIsInstance(expr, exp.Column) self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
self.assertIsInstance(expr.this, exp.Identifier) exp.Table
self.assertTrue(expr.this.args.get("global")) ).this.assert_is(exp.Parameter).this.assert_is(exp.Var)
self.assertEqual(expr.sql("tsql"), "##x")
expr = parse_one("@x", read="tsql")
self.assertIsInstance(expr, exp.Parameter)
self.assertIsInstance(expr.this, exp.Var)
self.assertEqual(expr.sql("tsql"), "@x")
table = parse_one("select * from @x", read="tsql").args["from"].this
self.assertIsInstance(table, exp.Table)
self.assertIsInstance(table.this, exp.Parameter)
self.assertIsInstance(table.this.this, exp.Var)
self.validate_all( self.validate_all(
"SELECT @x", "SELECT @x",
@@ -1663,8 +1662,6 @@ WHERE
"tsql": "SELECT @x", "tsql": "SELECT @x",
}, },
) )
def test_temp_table(self):
self.validate_all( self.validate_all(
"SELECT * FROM #mytemptable", "SELECT * FROM #mytemptable",
write={ write={


@@ -872,3 +872,4 @@ SELECT name
SELECT copy SELECT copy
SELECT rollup SELECT rollup
SELECT unnest SELECT unnest
SELECT * FROM a STRAIGHT_JOIN b


@@ -1047,6 +1047,9 @@ x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE);
TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME); TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME);
x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME); x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);
DATE_TRUNC('day', CAST(x AS DATE)) <= CAST('2021-01-01 01:02:03' AS TIMESTAMP);
CAST(x AS DATE) < CAST('2021-01-02 01:02:03' AS TIMESTAMP);
-------------------------------------- --------------------------------------
-- EQUALITY -- EQUALITY
-------------------------------------- --------------------------------------


@@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):
def qualify_columns(expression, **kwargs): def qualify_columns(expression, **kwargs):
expression = optimizer.qualify.qualify( expression = optimizer.qualify.qualify(
expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs expression,
infer_schema=True,
validate_qualify_columns=False,
identify=False,
**kwargs,
) )
return expression return expression
@@ -111,7 +115,14 @@ class TestOptimizer(unittest.TestCase):
} }
def check_file( def check_file(
self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs self,
file,
func,
pretty=False,
execute=False,
set_dialect=False,
only=None,
**kwargs,
): ):
with ProcessPoolExecutor() as pool: with ProcessPoolExecutor() as pool:
results = {} results = {}
@@ -331,7 +342,11 @@ class TestOptimizer(unittest.TestCase):
) )
self.check_file( self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True "qualify_columns",
qualify_columns,
execute=True,
schema=self.schema,
set_dialect=True,
) )
self.check_file( self.check_file(
"qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True "qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
@@ -343,7 +358,8 @@ class TestOptimizer(unittest.TestCase):
def test_pushdown_cte_alias_columns(self): def test_pushdown_cte_alias_columns(self):
self.check_file( self.check_file(
"pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns "pushdown_cte_alias_columns",
optimizer.qualify_columns.pushdown_cte_alias_columns,
) )
def test_qualify_columns__invalid(self): def test_qualify_columns__invalid(self):
@@ -405,7 +421,8 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy())) self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))
anon_unquoted_identifier = exp.Anonymous( anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")] this=exp.to_identifier("anonymous"),
expressions=[exp.column("x"), exp.column("y")],
) )
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)") self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")
@@ -416,7 +433,10 @@ class TestOptimizer(unittest.TestCase):
anon_invalid = exp.Anonymous(this=5) anon_invalid = exp.Anonymous(this=5)
optimizer.simplify.gen(anon_invalid) optimizer.simplify.gen(anon_invalid)
self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception)) self.assertIn(
"Anonymous.this expects a str or an Identifier, got 'int'.",
str(e.exception),
)
sql = parse_one( sql = parse_one(
""" """
@@ -906,7 +926,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip( for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] cte_select.find_all(exp.Subquery),
[exp.DataType.Type.CHAR, exp.DataType.Type.TEXT],
): ):
self.assertEqual(d.this.expressions[0].this.type.this, t) self.assertEqual(d.this.expressions[0].this.type.this, t)
@@ -1020,7 +1041,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for (func, col), target_type in tests.items(): for (func, col), target_type in tests.items():
expression = annotate_types( expression = annotate_types(
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"),
schema=schema,
) )
self.assertEqual(expression.expressions[0].type.this, target_type) self.assertEqual(expression.expressions[0].type.this, target_type)
@@ -1035,7 +1057,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this) self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)
def test_nested_type_annotation(self): def test_nested_type_annotation(self):
schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}} schema = {
"order": {
"customer_id": "bigint",
"item_id": "bigint",
"item_price": "numeric",
}
}
sql = """ sql = """
SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items, SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
FROM order AS order FROM order AS order
@@ -1057,7 +1085,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>") self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
self.assertEqual( self.assertEqual(
expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>" expression.selects[1].type.sql(dialect="bigquery"),
"ARRAY<STRUCT<`f` STRING>>",
) )
expression = annotate_types( expression = annotate_types(
@@ -1206,7 +1235,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual( self.assertEqual(
optimizer.optimize( optimizer.optimize(
parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery") parse_one("SELECT * FROM a"),
schema=MappingSchema(schema, dialect="bigquery"),
), ),
parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'), parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
) )


@@ -106,6 +106,7 @@ class TestParser(unittest.TestCase):
expr = parse_one("SELECT foo IN UNNEST(bla) AS bar") expr = parse_one("SELECT foo IN UNNEST(bla) AS bar")
self.assertIsInstance(expr.selects[0], exp.Alias) self.assertIsInstance(expr.selects[0], exp.Alias)
self.assertEqual(expr.selects[0].output_name, "bar") self.assertEqual(expr.selects[0].output_name, "bar")
self.assertIsNotNone(parse_one("select unnest(x)").find(exp.Unnest))
def test_unary_plus(self): def test_unary_plus(self):
self.assertEqual(parse_one("+15"), exp.Literal.number(15)) self.assertEqual(parse_one("+15"), exp.Literal.number(15))
@@ -880,10 +881,12 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or) self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or)
def test_trailing_comments(self): def test_trailing_comments(self):
expressions = parse(""" expressions = parse(
"""
select * from x; select * from x;
-- my comment -- my comment
""") """
)
self.assertEqual( self.assertEqual(
";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */" ";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */"