
Adding upstream version 6.0.4.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 06:15:54 +01:00
parent d01130b3f1
commit 527597d2af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
122 changed files with 23162 additions and 0 deletions

26
.github/workflows/python-package.yml vendored Normal file

@ -0,0 +1,26 @@
name: Test and Lint Python Package
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
- name: Run checks (linter, code style, tests)
run: |
./run_checks.sh

27
.github/workflows/python-publish.yml vendored Normal file

@ -0,0 +1,27 @@
name: Publish Python Release to PyPI
on:
push:
tags:
- "v*"
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*

132
.gitignore vendored Normal file

@ -0,0 +1,132 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# PyCharm
.idea/

3
.vscode/settings.json vendored Normal file

@ -0,0 +1,3 @@
{
"python.linting.pylintEnabled": true
}

21
LICENSE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Toby Mao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

330
README.md Normal file

@ -0,0 +1,330 @@
# SQLGlot
SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python.
You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations.
## Install
From PyPI
```
pip3 install sqlglot
```
Or with a local checkout
```
pip3 install -e .
```
## Examples
Easily translate from one dialect to another. For example, date/time functions vary between dialects and can be hard to deal with.
```python
import sqlglot
sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive')
```
```sql
SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC')
```
SQLGlot can even translate custom time formats.
```python
import sqlglot
sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read='duckdb', write='hive')
```
```sql
SELECT DATE_FORMAT(x, 'yy-M-ss')
```
## Formatting and Transpiling
Read in a SQL statement with a CTE and a cast to REAL, then transpile it to Spark.
Spark uses backticks for identifiers, and the REAL type is transpiled to FLOAT.
```python
import sqlglot
sql = """WITH baz AS (SELECT a, c FROM foo WHERE a = 1) SELECT f.a, b.b, baz.c, CAST("b"."a" AS REAL) d FROM foo f JOIN bar b ON f.a = b.a LEFT JOIN baz ON f.a = baz.a"""
sqlglot.transpile(sql, write='spark', identify=True, pretty=True)[0]
```
```sql
WITH `baz` AS (
SELECT
`a`,
`c`
FROM `foo`
WHERE
`a` = 1
)
SELECT
`f`.`a`,
`b`.`b`,
`baz`.`c`,
CAST(`b`.`a` AS FLOAT) AS `d`
FROM `foo` AS `f`
JOIN `bar` AS `b`
ON `f`.`a` = `b`.`a`
LEFT JOIN `baz`
ON `f`.`a` = `baz`.`a`
```
## Metadata
You can explore SQL with expression helpers to do things like find columns and tables.
```python
from sqlglot import parse_one, exp
# print all column references (a and b)
for column in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Column):
print(column.alias_or_name)
# find all projections in select statements (a and c)
for select in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Select):
for projection in select.expressions:
print(projection.alias_or_name)
# find all tables (x, y, z)
for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):
print(table.name)
```
## Parser Errors
A syntax error will result in a parser error.
```python
transpile("SELECT foo( FROM bar")
```
sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13.
select foo( __FROM__ bar
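The error can also be caught programmatically. A minimal sketch, using the `ParseError` class exported by the package:
```python
import sqlglot
from sqlglot.errors import ParseError

try:
    sqlglot.transpile("SELECT foo( FROM bar")
except ParseError as e:
    print(e)  # Expecting ). Line 1, Col: 13.
```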
## Unsupported Errors
Presto's APPROX_DISTINCT supports the accuracy argument, which is not supported in Spark.
```python
transpile(
'SELECT APPROX_DISTINCT(a, 0.1) FROM foo',
read='presto',
write='spark',
)
```
```sql
WARNING:root:APPROX_COUNT_DISTINCT does not support accuracy
SELECT APPROX_COUNT_DISTINCT(a) FROM foo
```
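Whether such incompatibilities warn or raise is configurable. A hedged sketch, assuming the generator accepts an `unsupported_level` option that `transpile` forwards to it (the exact option name may differ in this version):
```python
import sqlglot

# Assumption: unsupported_level is forwarded to the generator; adjust the name
# if the installed version exposes this setting differently.
sqlglot.transpile(
    "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
    read="presto",
    write="spark",
    unsupported_level=sqlglot.ErrorLevel.RAISE,
)
```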
## Build and Modify SQL
SQLGlot supports incrementally building SQL expressions.
```python
from sqlglot import select, condition
where = condition("x=1").and_("y=1")
select("*").from_("y").where(where).sql()
```
Which outputs:
```sql
SELECT * FROM y WHERE x = 1 AND y = 1
```
You can also modify a parsed tree:
```python
from sqlglot import parse_one
parse_one("SELECT x FROM y").from_("z").sql()
```
Which outputs:
```sql
SELECT x FROM y, z
```
There is also a way to recursively transform the parsed tree by applying a mapping function to each tree node:
```python
from sqlglot import exp, parse_one
expression_tree = parse_one("SELECT a FROM x")
def transformer(node):
if isinstance(node, exp.Column) and node.name == "a":
return parse_one("FUN(a)")
return node
transformed_tree = expression_tree.transform(transformer)
transformed_tree.sql()
```
Which outputs:
```sql
SELECT FUN(a) FROM x
```
## SQL Optimizer
SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine.
```python
import sqlglot
from sqlglot.optimizer import optimize
optimize(
sqlglot.parse_one("""
SELECT A OR (B OR (C AND D))
FROM x
WHERE Z = date '2021-01-01' + INTERVAL '1' month OR 1 = 0
"""),
schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
).sql(pretty=True)
"""
SELECT
(
"x"."A"
OR "x"."B"
OR "x"."C"
)
AND (
"x"."A"
OR "x"."B"
OR "x"."D"
) AS "_col_0"
FROM "x" AS "x"
WHERE
"x"."Z" = CAST('2021-02-01' AS DATE)
"""
```
## SQL Annotations
SQLGlot supports annotations in the SQL expression. This is an experimental feature that is not part of any SQL standard, but it can be useful to annotate what a selected field is supposed to be. Below is an example:
```sql
SELECT
user #primary_key,
country
FROM users
```
SQL annotations are currently incompatible with MySQL, which uses the `#` character to introduce comments.
## AST Introspection
You can see the AST version of the SQL by calling repr.
```python
from sqlglot import parse_one
repr(parse_one("SELECT a + 1 AS z"))
(SELECT expressions:
(ALIAS this:
(ADD this:
(COLUMN this:
(IDENTIFIER this: a, quoted: False)), expression:
(LITERAL this: 1, is_string: False)), alias:
(IDENTIFIER this: z, quoted: False)))
```
## AST Diff
SQLGlot can calculate the difference between two expressions and output the changes in the form of a sequence of actions needed to transform a source expression into a target one.
```python
from sqlglot import diff, parse_one
diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d"))
[
Remove(expression=(ADD this:
(COLUMN this:
(IDENTIFIER this: a, quoted: False)), expression:
(COLUMN this:
(IDENTIFIER this: b, quoted: False)))),
Insert(expression=(SUB this:
(COLUMN this:
(IDENTIFIER this: a, quoted: False)), expression:
(COLUMN this:
(IDENTIFIER this: b, quoted: False)))),
Move(expression=(COLUMN this:
(IDENTIFIER this: c, quoted: False))),
Keep(source=(IDENTIFIER this: b, quoted: False), target=(IDENTIFIER this: b, quoted: False)),
...
]
```
## Custom Dialects
[Dialects](sqlglot/dialects) can be added by subclassing Dialect.
```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer, TokenType
class Custom(Dialect):
class Tokenizer(Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
}
class Generator(Generator):
TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
TYPE_MAPPING = {
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.TEXT: "STRING",
}
Dialects["custom"]
```
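Once the class is defined, the metaclass registers it under its lowercased name, so it can be used like any built-in dialect. A quick sanity check (a sketch; the expected output is shown as a comment and assumes the exact definition above):
```python
import sqlglot

print(sqlglot.transpile("SELECT CAST(a AS INT)", write="custom")[0])
# SELECT CAST(a AS INT64)
```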
## Benchmarks
[Benchmarks](benchmarks) run on Python 3.10.5 in seconds.
| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| tpch | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) |
| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) |
| long | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) |
| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) |
## Run Tests and Lint
```
pip install -r requirements.txt
./run_checks.sh
```
## Optional Dependencies
SQLGlot uses [dateutil](https://github.com/dateutil/dateutil) to simplify literal timedelta expressions. The optimizer will not simplify expressions like
```sql
x + interval '1' month
```
if the module cannot be found.

225
benchmarks/bench.py Normal file

@ -0,0 +1,225 @@
import collections.abc
# moz_sql_parser 3.10 compatibility
collections.Iterable = collections.abc.Iterable
import gc
import timeit
import moz_sql_parser
import numpy as np
import sqloxide
import sqlparse
import sqltree
import sqlglot
long = """
SELECT
"e"."employee_id" AS "Employee #",
"e"."first_name" || ' ' || "e"."last_name" AS "Name",
"e"."email" AS "Email",
"e"."phone_number" AS "Phone",
TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date",
TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary",
"e"."commission_pct" AS "Comission %",
'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job",
TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary",
"l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location",
"jh"."job_id" AS "History Job ID",
'worked from ' || TO_CHAR("jh"."start_date", 'MM/DD/YYYY') || ' to ' || TO_CHAR("jh"."end_date", 'MM/DD/YYYY') || ' as ' || "jj"."job_title" || ' in ' || "dd"."department_name" || ' department' AS "History Job Title",
case when 1 then 1 when 2 then 2 when 3 then 3 when 4 then 4 when 5 then 5 else a(b(c + 1 * 3 % 4)) end
FROM "employees" AS e
JOIN "jobs" AS j
ON "e"."job_id" = "j"."job_id"
LEFT JOIN "employees" AS m
ON "e"."manager_id" = "m"."employee_id"
LEFT JOIN "departments" AS d
ON "d"."department_id" = "e"."department_id"
LEFT JOIN "employees" AS dm
ON "d"."manager_id" = "dm"."employee_id"
LEFT JOIN "locations" AS l
ON "d"."location_id" = "l"."location_id"
LEFT JOIN "countries" AS c
ON "l"."country_id" = "c"."country_id"
LEFT JOIN "regions" AS r
ON "c"."region_id" = "r"."region_id"
LEFT JOIN "job_history" AS jh
ON "e"."employee_id" = "jh"."employee_id"
LEFT JOIN "jobs" AS jj
ON "jj"."job_id" = "jh"."job_id"
LEFT JOIN "departments" AS dd
ON "dd"."department_id" = "jh"."department_id"
ORDER BY
"e"."employee_id"
"""
short = "select 1 as a, case when 1 then 1 when 2 then 2 else 3 end as b, c from x"
crazy = "SELECT 1+"
crazy += "+".join(str(i) for i in range(500))
crazy += " AS a, 2*"
crazy += "*".join(str(i) for i in range(500))
crazy += " AS b FROM x"
tpch = """
WITH "_e_0" AS (
SELECT
"partsupp"."ps_partkey" AS "ps_partkey",
"partsupp"."ps_suppkey" AS "ps_suppkey",
"partsupp"."ps_supplycost" AS "ps_supplycost"
FROM "partsupp" AS "partsupp"
), "_e_1" AS (
SELECT
"region"."r_regionkey" AS "r_regionkey",
"region"."r_name" AS "r_name"
FROM "region" AS "region"
WHERE
"region"."r_name" = 'EUROPE'
)
SELECT
"supplier"."s_acctbal" AS "s_acctbal",
"supplier"."s_name" AS "s_name",
"nation"."n_name" AS "n_name",
"part"."p_partkey" AS "p_partkey",
"part"."p_mfgr" AS "p_mfgr",
"supplier"."s_address" AS "s_address",
"supplier"."s_phone" AS "s_phone",
"supplier"."s_comment" AS "s_comment"
FROM (
SELECT
"part"."p_partkey" AS "p_partkey",
"part"."p_mfgr" AS "p_mfgr",
"part"."p_type" AS "p_type",
"part"."p_size" AS "p_size"
FROM "part" AS "part"
WHERE
"part"."p_size" = 15
AND "part"."p_type" LIKE '%BRASS'
) AS "part"
LEFT JOIN (
SELECT
MIN("partsupp"."ps_supplycost") AS "_col_0",
"partsupp"."ps_partkey" AS "_u_1"
FROM "_e_0" AS "partsupp"
CROSS JOIN "_e_1" AS "region"
JOIN (
SELECT
"nation"."n_nationkey" AS "n_nationkey",
"nation"."n_regionkey" AS "n_regionkey"
FROM "nation" AS "nation"
) AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN (
SELECT
"supplier"."s_suppkey" AS "s_suppkey",
"supplier"."s_nationkey" AS "s_nationkey"
FROM "supplier" AS "supplier"
) AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
GROUP BY
"partsupp"."ps_partkey"
) AS "_u_0"
ON "part"."p_partkey" = "_u_0"."_u_1"
CROSS JOIN "_e_1" AS "region"
JOIN (
SELECT
"nation"."n_nationkey" AS "n_nationkey",
"nation"."n_name" AS "n_name",
"nation"."n_regionkey" AS "n_regionkey"
FROM "nation" AS "nation"
) AS "nation"
ON "nation"."n_regionkey" = "region"."r_regionkey"
JOIN "_e_0" AS "partsupp"
ON "part"."p_partkey" = "partsupp"."ps_partkey"
JOIN (
SELECT
"supplier"."s_suppkey" AS "s_suppkey",
"supplier"."s_name" AS "s_name",
"supplier"."s_address" AS "s_address",
"supplier"."s_nationkey" AS "s_nationkey",
"supplier"."s_phone" AS "s_phone",
"supplier"."s_acctbal" AS "s_acctbal",
"supplier"."s_comment" AS "s_comment"
FROM "supplier" AS "supplier"
) AS "supplier"
ON "supplier"."s_nationkey" = "nation"."n_nationkey"
AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey"
WHERE
"partsupp"."ps_supplycost" = "_u_0"."_col_0"
AND NOT "_u_0"."_u_1" IS NULL
ORDER BY
"supplier"."s_acctbal" DESC,
"nation"."n_name",
"supplier"."s_name",
"part"."p_partkey"
LIMIT 100
"""
def sqlglot_parse(sql):
sqlglot.parse(sql, error_level=sqlglot.ErrorLevel.IGNORE)
def sqltree_parse(sql):
sqltree.api.sqltree(sql.replace('"', '`').replace("''", '"'))
def sqlparse_parse(sql):
sqlparse.parse(sql)
def moz_sql_parser_parse(sql):
moz_sql_parser.parse(sql)
def sqloxide_parse(sql):
sqloxide.parse_sql(sql, dialect="ansi")
def border(columns):
columns = " | ".join(columns)
return f"| {columns} |"
def diff(row, column):
if column == "Query":
return ""
column = row[column]
if isinstance(column, str):
return " (N/A)"
return f" ({str(column / row['sqlglot'])[0:5]})"
libs = [
"sqlglot",
"sqltree",
"sqlparse",
"moz_sql_parser",
"sqloxide",
]
table = []
for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.items():
row = {"Query": name}
table.append(row)
for lib in libs:
try:
row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3))
except Exception:
row[lib] = "error"
columns = ["Query"] + libs
widths = {column: max(len(column), 15) for column in columns}
lines = [border(column.rjust(width) for column, width in widths.items())]
lines.append(border(str("-" * width) for width in widths.values()))
for i, row in enumerate(table):
lines.append(border(
(str(row[column])[0:7] + diff(row, column)).rjust(width)[0 : width]
for column, width in widths.items()
))
for line in lines:
print(line)

389
posts/sql_diff.md Normal file

@ -0,0 +1,389 @@
# Semantic Diff for SQL
*by [Iaroslav Zeigerman](https://github.com/izeigerman)*
## Motivation
Software is constantly changing and evolving, and identifying what has changed and reviewing those changes is an integral part of the development process. SQL code is no exception to this.
Text-based diff tools such as `git diff`, when applied to a code base, have certain limitations. First, they can only detect insertions and deletions, not movements or updates of individual pieces of code. Second, such tools can only detect changes between lines of text, which is too coarse for something as granular and detailed as source code. Additionally, the outcome of such a diff is dependent on the underlying code formatting, and yields different results if the formatting should change.
Consider the following diff generated by Git:
![Git diff output](sql_diff_images/git_diff_output.png)
Semantically, the query hasn't changed. The two arguments `b` and `c` have been swapped (moved), which has no impact on the output of the query. Yet Git replaced the whole affected expression along with a number of unrelated elements.
The alternative to text-based diffing is to compare Abstract Syntax Trees (AST) instead. The main advantage of ASTs is that they are a direct product of code parsing and represent the underlying code structure at any desired level of granularity. Comparing ASTs may yield extremely precise diffs; changes such as code movements and updates can also be detected. Even more importantly, this approach facilitates additional use cases beyond eyeballing two versions of source code side by side.
The use cases I had in mind for SQL when I decided to embark on this journey of semantic diffing were the following:
* **Query similarity score.** Identifying which parts the two queries have in common to automatically suggest opportunities for consolidation, creation of intermediate/staging tables, and so on.
* **Differentiating between cosmetic / structural changes and functional ones.** For example, when a nested query is refactored into a common table expression (CTE), this kind of change doesn't have any functional impact on either the query or its outcome.
* **Automatic suggestions about the need to retroactively backfill data.** This is especially important for pipelines that populate very large tables for which restatement is a runtime-intensive procedure. The ability to discern between simple code movements and actual modifications can help assess the impact of a change and make suggestions accordingly.
The implementation discussed in this post is now a part of the [SQLGlot](https://github.com/tobymao/sqlglot/) library. You can find the complete source code in the [diff.py](https://github.com/tobymao/sqlglot/blob/main/sqlglot/diff.py) module. The choice of SQLGlot was an obvious one due to its simple but powerful API, lack of external dependencies and, more importantly, its extensive list of supported SQL dialects.
## The Search for a Solution
When it comes to any diffing tool (not just a semantic one), the primary challenge is to match as many elements of compared entities as possible. Once such a set of matching elements is available, deriving a sequence of changes becomes an easy task.
If our elements have unique identifiers associated with them (for example, an element's ID in the DOM), the matching problem is trivial. However, the SQL syntax trees that we are comparing have neither unique keys nor object identifiers that can be used for the purposes of matching. So how are we supposed to find pairs of nodes that are related?
To better illustrate the problem, consider comparing the following SQL expressions: `SELECT a + b + c, d, e` and `SELECT a - b + c, e, f`. Matching individual nodes from respective syntax trees can be visualized as follows:
![Figure 1: Example of node matching for two SQL expression trees](sql_diff_images/figure_1.png)
*Figure 1: Example of node matching for two SQL expression trees.*
By looking at the figure of node matching for two SQL expression trees above, we conclude that the following changes should be captured by our solution:
* Inserted nodes: `Sub` and `f`. These are the nodes from the target AST which do not have a matching node in the source AST.
* Removed nodes: `Add` and `d`. These are the nodes from the source AST which do not have a counterpart in the target AST.
* Remaining nodes must be identified as unchanged.
It should be clear at this point that if we manage to match nodes in the source tree with their counterparts in the target tree, then computing the diff becomes a trivial matter.
### Naïve Brute-Force
The naïve solution would be to try all different permutations of node pair combinations, and see which set of pairs performs the best based on some type of heuristics. The runtime cost of such a solution quickly reaches escape velocity; if both trees had only 10 nodes each, the number of such sets would be approximately 10! ^ 2 = 3.6M ^ 2 ~= 13 * 10^12. This is a very bad case of factorial complexity (to be precise, it's actually much worse - O(n! ^ 2) - but I couldn't come up with a name for it), so there is little need to explore this approach any further.
### Myers Algorithm
After the naïve approach was proven to be infeasible, the next question I asked myself was “how does git diff work?”. This question led me to discover the Myers diff algorithm [1]. This algorithm has been designed to compare sequences of strings. At its core, it's looking for the shortest path on a graph of possible edits that transform the first sequence into the second one, while heavily rewarding those paths that lead to the longest subsequences of unchanged elements. There's a lot of material out there describing this algorithm in greater detail. I found James Coglan's series of [blog posts](https://blog.jcoglan.com/2017/02/12/the-myers-diff-algorithm-part-1/) to be the most comprehensive.
Therefore, I had this “brilliant” (actually not) idea to transform trees into sequences by traversing them in topological order, and then applying the Myers algorithm on the resulting sequences while using a custom heuristic when checking the equality of two nodes. Unsurprisingly, comparing sequences of strings is quite different from comparing hierarchical tree structures, and by flattening trees into sequences, we lose a lot of relevant context. This resulted in terrible performance of this algorithm on ASTs. It often matched completely unrelated nodes, even when the two trees were mostly the same, and produced extremely inaccurate lists of changes overall. After playing around with it a little and tweaking my equality heuristics to improve accuracy, I ultimately scrapped the whole implementation and went back to the drawing board.
## Change Distiller
The algorithm I settled on in the end was Change Distiller, created by Fluri et al. [2], which in turn is an improvement over the core idea described by Chawathe et al. [3].
The algorithm consists of two high-level steps:
1. **Finding appropriate matchings between pairs of nodes that are part of compared ASTs.** Identifying what is meant by “appropriate” matching is also a part of this step.
2. **Generating the so-called “edit script” from the matching set built in the 1st step.** The edit script is a sequence of edit operations (for example, insert, remove, update, etc.) on individual tree nodes, such that when applied as transformations on the source AST, it eventually becomes the target AST. In general, the shorter the sequence, the better. The length of the edit script can be used to compare the performance of different algorithms, though this is not the only metric that matters.
The rest of this section is dedicated to the Python implementation of the steps above using the AST implementation provided by the SQLGlot library.
### Building the Matching Set
#### Matching Leaves
We begin composing the matching set by matching the leaf nodes. Leaf nodes are the nodes that do not have any child nodes (such as literals, identifiers, etc.). In order to match them, we gather all the leaf nodes from the source tree and generate a cartesian product with all the leaves from the target tree, comparing the pairs created this way and assigning them a similarity score. During this stage, we also exclude pairs that don't pass basic matching criteria. Then, we pick the pairs that scored the highest while making sure that each node is matched no more than once.
Using the example provided at the beginning of the post, the process of building an initial set of candidate matchings can be seen in Figure 2.
![Figure 2: Building a set of candidate matchings between leaf nodes. The third item in each triplet represents a similarity score between two nodes.](sql_diff_images/figure_2.gif)
*Figure 2: Building a set of candidate matchings between leaf nodes. The third item in each triplet represents a similarity score between two nodes.*
First, let's analyze the similarity score. Then, we'll discuss the matching criteria.
The similarity score proposed by Fluri et al. [2] is a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) applied to [bigrams](https://en.wikipedia.org/wiki/Bigram) of respective node values. A bigram is a sequence of two adjacent elements from a string computed in a sliding window fashion:
```python
def bigram(string):
count = max(0, len(string) - 1)
return [string[i : i + 2] for i in range(count)]
```
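For example:
```python
>>> bigram("SELECT")
['SE', 'EL', 'LE', 'EC', 'CT']
```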
For reasons that will become clear shortly, we actually need to compute bigram histograms rather than just sequences:
```python
from collections import defaultdict
def bigram_histo(string):
count = max(0, len(string) - 1)
bigram_histo = defaultdict(int)
for i in range(count):
bigram_histo[string[i : i + 2]] += 1
return bigram_histo
```
The dice coefficient formula looks like the following:
![Dice Coefficient](sql_diff_images/dice_coef.png)
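In other words (restating the formula above in text):
```latex
\mathrm{dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}
```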
Here, X is the set of bigrams of the source node and Y is the set of bigrams of the target node. What this essentially does is count the number of bigram elements the two nodes have in common, multiply it by 2, and then divide by the total number of elements in both bigrams. This is where bigram histograms come in handy:
```python
def dice_coefficient(source, target):
source_histo = bigram_histo(source.sql())
target_histo = bigram_histo(target.sql())
total_grams = (
sum(source_histo.values()) + sum(target_histo.values())
)
if not total_grams:
return 1.0 if source == target else 0.0
overlap_len = 0
overlapping_grams = set(source_histo) & set(target_histo)
for g in overlapping_grams:
overlap_len += min(source_histo[g], target_histo[g])
return 2 * overlap_len / total_grams
```
To compute a bigram given a tree node, we first transform the node into its canonical SQL representation, so that the `Literal(123)` node becomes just “123” and the `Identifier(“a”)` node becomes just “a”. We also handle the scenario in which strings are too short to derive bigrams. In this case, we fall back to checking the two nodes for equality.
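For instance, for two hypothetical single-column expressions:
```python
>>> from sqlglot import parse_one
>>> dice_coefficient(parse_one("tax_amount"), parse_one("tax_amt"))
0.6666666666666666
```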
Now that we know how to compute the similarity score, we can take care of the matching criteria for leaf nodes. In the original paper [2], the matching criteria is formalized as follows:
![Matching criteria for leaf nodes](sql_diff_images/matching_criteria_1.png)
The two nodes are matched if two conditions are met:
1. The node labels match (in our case labels are just node types).
2. The similarity score for node values is greater than or equal to some threshold “f”. The authors of the paper recommend setting the value of “f” to 0.6.
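In our case, checking node labels amounts to comparing node types. A minimal sketch of the `_is_same_type` helper used in the snippets below (the actual implementation may include additional checks):
```python
def _is_same_type(source, target):
    # Labels here are just the expression classes, e.g. Column vs. Column.
    return type(source) is type(target)
```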
With building blocks in place, we can now build a matching set for leaf nodes. First, we generate a list of candidates for matching:
```python
from heapq import heappush, heappop
candidate_matchings = []
source_leaves = _get_leaves(self._source)
target_leaves = _get_leaves(self._target)
for source_leaf in source_leaves:
for target_leaf in target_leaves:
if _is_same_type(source_leaf, target_leaf):
similarity_score = dice_coefficient(
source_leaf, target_leaf
)
if similarity_score >= 0.6:
heappush(
candidate_matchings,
(
-similarity_score,
len(candidate_matchings),
source_leaf,
target_leaf,
),
)
```
In the implementation above, we push each matching pair onto the heap to automatically maintain the correct order based on the assigned similarity score.
Finally, we build the initial matching set by picking leaf pairs with the highest score:
```python
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
if (
source_leaf in unmatched_source_nodes
and target_leaf in unmatched_target_nodes
):
matching_set.add((source_leaf, target_leaf))
unmatched_source_nodes.remove(source_leaf)
unmatched_target_nodes.remove(target_leaf)
```
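Note that these snippets assume `unmatched_source_nodes` and `unmatched_target_nodes` already contain every node of the source and target trees; matched nodes are removed from them as pairs are accepted. A possible initialization (a sketch using a hypothetical `_all_nodes` helper on top of the same `_get_child_nodes` helper that appears later in the post, not code from the actual module):
```python
def _all_nodes(expression):
    # Recursively collect the expression itself and all of its descendants.
    nodes = [expression]
    for child in _get_child_nodes(expression):
        nodes.extend(_all_nodes(child))
    return nodes

unmatched_source_nodes = set(_all_nodes(self._source))
unmatched_target_nodes = set(_all_nodes(self._target))
```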
To finalize the matching set, we should now proceed with matching inner nodes.
#### Matching Inner Nodes
Matching inner nodes is quite similar to matching leaf nodes, with the following two distinctions:
* Rather than ranking a set of possible candidates, we pick the first node pair that passes the matching criteria.
* The matching criteria itself has been extended to account for the number of leaf nodes the pair of inner nodes have in common.
![Figure 3: Matching inner nodes based on their type as well as how many of their leaf nodes have been previously matched.](sql_diff_images/figure_3.gif)
*Figure 3: Matching inner nodes based on their type as well as how many of their leaf nodes have been previously matched.*
Let's start with the matching criteria. The criteria is formalized as follows:
![Matching criteria for inner nodes](sql_diff_images/matching_criteria_2.png)
Alongside the already familiar similarity score and node type criteria, there is a new one in the middle: the ratio of leaf nodes that the two nodes have in common must exceed some threshold “t”. The recommended value for “t” is also 0.6. Counting the number of common leaf nodes is pretty straightforward, since we already have the complete matching set for leaves. All we need to do is count how many matching pairs the leaf nodes from the two compared inner nodes form.
There are two additional heuristics associated with this matching criteria:
* Inner node similarity weighting: if the similarity score between the node values doesn't pass the threshold “f” but the ratio of common leaf nodes (“t”) is greater than or equal to 0.8, then the matching is considered successful.
* The threshold “t” is reduced to 0.4 for inner nodes with the number of leaf nodes equal to 4 or less, in order to decrease the false negative rate for small subtrees.
We now only have to iterate through the remaining unmatched nodes and form matching pairs based on the outlined criteria:
```python
leaves_matching_set = matching_set.copy()
for source_node in unmatched_source_nodes.copy():
for target_node in unmatched_target_nodes:
if _is_same_type(source_node, target_node):
source_leaves = set(_get_leaves(source_node))
target_leaves = set(_get_leaves(target_node))
max_leaves_num = max(len(source_leaves), len(target_leaves))
if max_leaves_num:
common_leaves_num = sum(
1 if s in source_leaves and t in target_leaves else 0
for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
adjusted_t = (
0.6
if min(len(source_leaves), len(target_leaves)) > 4
else 0.4
)
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
and dice_coefficient(source_node, target_node) >= 0.6
):
matching_set.add((source_node, target_node))
unmatched_source_nodes.remove(source_node)
unmatched_target_nodes.remove(target_node)
break
```
After the matching set is formed, we can proceed with generation of the edit script, which will be the algorithms output.
### Generating the Edit Script
At this point, we should have the following 3 sets at our disposal:
* The set of matched node pairs.
* The set of remaining unmatched nodes from the source tree.
* The set of remaining unmatched nodes from the target tree.
We can derive 3 kinds of edits from the matching set: either the node's value was updated (**Update**), the node was moved to a different position within the tree (**Move**), or the node remained unchanged (**Keep**). Note that the **Move** case is not mutually exclusive with the other two. The node could have been updated or could have remained the same while at the same time its position within its parent node or the parent node itself could have changed. All unmatched nodes from the source tree are the ones that were removed (**Remove**), while unmatched nodes from the target tree are the ones that were inserted (**Insert**).
The latter two cases are pretty straightforward to implement:
```python
edit_script = []
for removed_node in unmatched_source_nodes:
edit_script.append(Remove(removed_node))
for inserted_node in unmatched_target_nodes:
edit_script.append(Insert(inserted_node))
```
Traversing the matching set requires a little more thought:
```python
for source_node, target_node in matching_set:
if (
not isinstance(source_node, LEAF_EXPRESSION_TYPES)
or source_node == target_node
):
move_edits = generate_move_edits(
source_node, target_node, matching_set
)
edit_script.extend(move_edits)
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
```
If a matching pair represents a pair of leaf nodes, we check if they are the same to decide whether an update took place. For inner node pairs, we also need to compare the positions of their respective children to detect node movements. Chawathe et al. [3] suggest applying the [longest common subsequence](https://en.wikipedia.org/wiki/Longest_common_subsequence_problem) (LCS) algorithm which, no surprise here, was described by Myers himself [1]. There is a small catch, however: instead of checking the equality of two child nodes, we need to check whether the two nodes form a pair that is a part of our matching set.
Now with this knowledge, the implementation becomes straightforward:
```python
def generate_move_edits(source, target, matching_set):
source_children = _get_child_nodes(source)
target_children = _get_child_nodes(target)
lcs = set(
_longest_common_subsequence(
source_children,
target_children,
lambda l, r: (l, r) in matching_set
)
)
move_edits = []
for node in source_children:
if node not in lcs and node not in unmatched_source_nodes:
move_edits.append(Move(node))
return move_edits
```
I left out the implementation of the LCS algorithm itself here, but there are plenty of implementation choices out there that can be easily looked up.
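For reference, a minimal dynamic-programming sketch matching the call above (not the post's actual implementation) could look like this; it yields the source-side elements that are part of the longest common subsequence:
```python
def _longest_common_subsequence(source, target, equal):
    m, n = len(source), len(target)
    # lengths[i][j] holds the LCS length of source[:i] and target[:j].
    lengths = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m):
        for j in range(n):
            if equal(source[i], target[j]):
                lengths[i + 1][j + 1] = lengths[i][j] + 1
            else:
                lengths[i + 1][j + 1] = max(lengths[i + 1][j], lengths[i][j + 1])

    # Walk the table backwards to recover the matched source elements.
    result = []
    i, j = m, n
    while i > 0 and j > 0:
        if equal(source[i - 1], target[j - 1]):
            result.append(source[i - 1])
            i, j = i - 1, j - 1
        elif lengths[i - 1][j] >= lengths[i][j - 1]:
            i -= 1
        else:
            j -= 1
    return list(reversed(result))
```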
### Output
The implemented algorithm produces output that resembles the following:
```python
>>> from sqlglot import parse_one, diff
>>> diff(parse_one("SELECT a + b + c, d, e"), parse_one("SELECT a - b + c, e, f"))
Remove(Add)
Remove(Column(d))
Remove(Identifier(d))
Insert(Sub)
Insert(Column(f))
Insert(Identifier(f))
Keep(Select, Select)
Keep(Add, Add)
Keep(Column(a), Column(a))
Keep(Identifier(a), Identifier(a))
Keep(Column(b), Column(b))
Keep(Identifier(b), Identifier(b))
Keep(Column(c), Column(c))
Keep(Identifier(c), Identifier(c))
Keep(Column(e), Column(e))
Keep(Identifier(e), Identifier(e))
```
Note that the output above is abbreviated. The string representation of actual AST nodes is significantly more verbose.
The implementation works especially well when coupled with SQLGlot's query optimizer, which can be used to produce canonical representations of compared queries:
```python
>>> schema={"t": {"a": "INT", "b": "INT", "c": "INT", "d": "INT"}}
>>> source = """
... SELECT 1 + 1 + a
... FROM t
... WHERE b = 1 OR (c = 2 AND d = 3)
... """
>>> target = """
... SELECT 2 + a
... FROM t
... WHERE (b = 1 OR c = 2) AND (b = 1 OR d = 3)
... """
>>> optimized_source = optimize(parse_one(source), schema=schema)
>>> optimized_target = optimize(parse_one(target), schema=schema)
>>> edit_script = diff(optimized_source, optimized_target)
>>> sum(0 if isinstance(e, Keep) else 1 for e in edit_script)
0
```
### Optimizations
The worst-case runtime complexity of this algorithm is not exactly stellar: O(n^2 * log n^2). This is because of the leaf matching process, which involves ranking a cartesian product between all leaf nodes of the compared trees. Unsurprisingly, the algorithm takes considerable time to finish for bigger queries.
There are still a few basic things we can do in our implementation to help improve performance:
* Refer to individual node objects using their identifiers (Python's [id()](https://docs.python.org/3/library/functions.html#id)) instead of direct references in sets. This helps avoid costly recursive hash calculations and equality checks.
* Cache bigram histograms to avoid computing them more than once for the same node.
* Compute the canonical SQL string representation for each tree once while caching string representations of all inner nodes. This prevents redundant tree traversals when bigrams are computed.
At the time of writing, only the first two optimizations have been implemented, so there is an opportunity to contribute for anyone who's interested.
## Alternative Solutions
This section is dedicated to solutions that I've investigated but haven't tried.
First, this section wouldn't be complete without Tristan Hume's [blog post](https://thume.ca/2017/06/17/tree-diffing/). Tristan's solution has a lot in common with the Myers algorithm plus heuristics that are much more clever than what I came up with. The implementation relies on a combination of [dynamic programming](https://en.wikipedia.org/wiki/Dynamic_programming) and the [A* search algorithm](https://en.wikipedia.org/wiki/A*_search_algorithm) to explore the space of possible matchings and pick the best ones. It seemed to have worked well for Tristan's specific use case, but after my negative experience with the Myers algorithm, I decided to try something different.
Another notable approach is the Gumtree algorithm by Falleri et al. [4]. I discovered this paper after I'd already implemented the algorithm that is the main focus of this post. In sections 5.2 and 5.3 of their paper, the authors compare the two algorithms side by side and claim that Gumtree is significantly better in terms of both runtime performance and accuracy when evaluated on 12,792 pairs of Java source files. This doesn't surprise me, as the algorithm takes the height of subtrees into account. In my tests, I definitely saw scenarios in which this context would have helped. On top of that, the authors promise O(n^2) runtime complexity in the worst case which, given Change Distiller's O(n^2 * log n^2), looks particularly tempting. I hope to try this algorithm out at some point, and there is a good chance you'll see me writing about it in future posts.
## Conclusion
The Change Distiller algorithm yielded quite satisfactory results in most of my tests. The scenarios in which it fell short mostly concerned identical (or very similar) subtrees located in different parts of the AST. In those cases, node mismatches were frequent and, as a result, edit scripts were somewhat suboptimal.
Additionally, the runtime performance of the algorithm leaves a lot to be desired. On trees with 1000 leaf nodes each, the algorithm takes a little under 2 seconds to complete. My implementation still has room for improvement, but this should give you a rough idea of what to expect. It appears that the Gumtree algorithm [4] can help address both of these points. I hope to find bandwidth to work on it soon and then compare the two algorithms side-by-side to find out which one performs better on SQL specifically. In the meantime, Change Distiller definitely gets the job done, and I can now proceed with applying it to some of the use cases I mentioned at the beginning of this post.
I'm also curious to learn whether other folks in the industry have faced a similar problem, and how they approached it. If you did something similar, I'm interested to hear about your experience.
## References
[1] Eugene W. Myers. [An O(ND) Difference Algorithm and Its Variations](http://www.xmailserver.org/diff2.pdf). Algorithmica 1(2): 251-266 (1986)
[2] B. Fluri, M. Wursch, M. Pinzger, and H. Gall. [Change Distilling: Tree differencing for fine-grained source code change extraction](https://www.researchgate.net/publication/3189787_Change_DistillingTree_Differencing_for_Fine-Grained_Source_Code_Change_Extraction). IEEE Trans. Software Eng., 33(11):725–743, 2007.
[3] S.S. Chawathe, A. Rajaraman, H. Garcia-Molina, and J. Widom. [Change Detection in Hierarchically Structured Information](http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf). Proc. ACM SIGMOD Int'l Conf. Management of Data, pp. 493-504, June 1996.
[4] Jean-Rémy Falleri, Floréal Morandat, Xavier Blanc, Matias Martinez, Martin Monperrus. [Fine-grained and Accurate Source Code Differencing](https://hal.archives-ouvertes.fr/hal-01054552/document). Proceedings of the International Conference on Automated Software Engineering, 2014, Västerås, Sweden. pp. 313-324, 10.1145/2642937.2642982. hal-01054552.

7 binary image files added under posts/sql_diff_images/ (not shown): 22 KiB, 56 KiB, 1.4 MiB, 320 KiB, 41 KiB, 34 KiB, 47 KiB

6
requirements.txt Normal file

@ -0,0 +1,6 @@
autoflake
black
duckdb
isort
pandas
python-dateutil

12
run_checks.sh Executable file

@ -0,0 +1,12 @@
#!/bin/bash -e
python -m autoflake -i -r \
--expand-star-imports \
--remove-all-unused-imports \
--ignore-init-module-imports \
--remove-duplicate-keys \
--remove-unused-variables \
sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
python -m black sqlglot/ tests/
python -m unittest

33
setup.py Normal file

@ -0,0 +1,33 @@
from setuptools import find_packages, setup
version = (
open("sqlglot/__init__.py")
.read()
.split("__version__ = ")[-1]
.split("\n")[0]
.strip("")
.strip("'")
.strip('"')
)
setup(
name="sqlglot",
version=version,
description="An easily customizable SQL parser and transpiler",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/tobymao/sqlglot",
author="Toby Mao",
author_email="toby.mao@gmail.com",
license="MIT",
packages=find_packages(include=["sqlglot", "sqlglot.*"]),
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: SQL",
"Programming Language :: Python :: 3 :: Only",
],
)

96
sqlglot/__init__.py Normal file

@ -0,0 +1,96 @@
from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError
from sqlglot.expressions import Expression
from sqlglot.expressions import alias_ as alias
from sqlglot.expressions import (
and_,
column,
condition,
from_,
maybe_parse,
not_,
or_,
select,
subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.0.4"
pretty = False
def parse(sql, read=None, **opts):
"""
Parses the given SQL string into a collection of syntax trees, one per
parsed SQL statement.
Args:
sql (str): the SQL code string to parse.
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
**opts: other options.
Returns:
typing.List[Expression]: the list of parsed syntax trees.
"""
dialect = Dialect.get_or_raise(read)()
return dialect.parse(sql, **opts)
def parse_one(sql, read=None, into=None, **opts):
"""
Parses the given SQL string and returns a syntax tree for the first
parsed SQL statement.
Args:
sql (str): the SQL code string to parse.
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
into (Expression): the SQLGlot Expression to parse into
**opts: other options.
Returns:
Expression: the syntax tree for the first parsed statement.
"""
dialect = Dialect.get_or_raise(read)()
if into:
result = dialect.parse_into(into, sql, **opts)
else:
result = dialect.parse(sql, **opts)
return result[0] if result else None
def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
"""
Parses the given SQL string using the source dialect and returns a list of SQL strings
transformed to conform to the target dialect. Each string in the returned list represents
a single transformed SQL statement.
Args:
sql (str): the SQL code string to transpile.
read (str): the source dialect used to parse the input string
(eg. "spark", "hive", "presto", "mysql").
write (str): the target dialect into which the input should be transformed
(eg. "spark", "hive", "presto", "mysql").
identity (bool): if set to True and the target dialect is not specified,
the source dialect will be used as both the source and the target dialect.
error_level (ErrorLevel): the desired error level of the parser.
**opts: other options.
Returns:
typing.List[str]: the list of transpiled SQL statements / expressions.
"""
write = write or read if identity else write
return [
Dialect.get_or_raise(write)().generate(expression, **opts)
for expression in parse(sql, read, error_level=error_level)
]

69
sqlglot/__main__.py Normal file

@ -0,0 +1,69 @@
import argparse
import sqlglot
parser = argparse.ArgumentParser(description="Transpile SQL")
parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
parser.add_argument(
"--read",
dest="read",
type=str,
default=None,
help="Dialect to read default is generic",
)
parser.add_argument(
"--write",
dest="write",
type=str,
default=None,
help="Dialect to write default is generic",
)
parser.add_argument(
"--no-identify",
dest="identify",
action="store_false",
help="Don't auto identify fields",
)
parser.add_argument(
"--no-pretty",
dest="pretty",
action="store_false",
help="Compress sql",
)
parser.add_argument(
"--parse",
dest="parse",
action="store_true",
help="Parse and return the expression tree",
)
parser.add_argument(
"--error-level",
dest="error_level",
type=str,
default="RAISE",
help="IGNORE, WARN, RAISE (default)",
)
args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
if args.parse:
sqls = [
repr(expression)
for expression in sqlglot.parse(
args.sql, read=args.read, error_level=error_level
)
]
else:
sqls = sqlglot.transpile(
args.sql,
read=args.read,
write=args.write,
identify=args.identify,
pretty=args.pretty,
error_level=error_level,
)
for sql in sqls:
print(sql)

15
sqlglot/dialects/__init__.py Normal file

@ -0,0 +1,15 @@
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
from sqlglot.dialects.trino import Trino

128
sqlglot/dialects/bigquery.py Normal file

@ -0,0 +1,128 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
return expression_class(
this=list_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
return func
def _date_add_sql(data_type, kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = self.sql(expression, "unit") or "'day'"
expression = self.sql(expression, "expression")
return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
return func
class BigQuery(Dialect):
unnest_column_only = True
class Tokenizer(Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
KEYWORDS = {
**Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
}
NO_PAREN_FUNCTIONS = {
**Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
class Generator(Generator):
TRANSFORMS = {
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.ILike: no_ilike_sql,
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.NVARCHAR: "STRING",
}
def in_unnest_op(self, unnest):
return self.sql(unnest)
def union_op(self, expression):
return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def intersect_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported(
"INTERSECT without DISTINCT is not supported in BigQuery"
)
return (
f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
)

48
sqlglot/dialects/clickhouse.py Normal file

@ -0,0 +1,48 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', "`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"NULLABLE": TokenType.NULLABLE,
"FINAL": TokenType.FINAL,
"INT8": TokenType.TINYINT,
"INT16": TokenType.SMALLINT,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
}
class Parser(Parser):
def _parse_table(self, schema=False):
this = super()._parse_table(schema)
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
return this
class Generator(Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
}

268
sqlglot/dialects/dialect.py Normal file

@ -0,0 +1,268 @@
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import csv, list_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
from sqlglot.trie import new_trie
class Dialects(str, Enum):
DIALECT = ""
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DUCKDB = "duckdb"
HIVE = "hive"
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
TRINO = "trino"
class _Dialect(type):
classes = {}
@classmethod
def __getitem__(cls, key):
return cls.classes[key]
@classmethod
def get(cls, key, default=None):
return cls.classes.get(key, default)
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
enum = Dialects.__members__.get(clsname.upper())
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
klass.time_trie = new_trie(klass.time_mapping)
klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.tokenizer = klass.tokenizer_class()
klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
0
]
klass.identifier_start, klass.identifier_end = list(
klass.tokenizer_class.IDENTIFIERS.items()
)[0]
return klass
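# Note (editorial sketch, not part of the upstream file): every Dialect subclass passes
# through _Dialect.__new__, which registers it in _Dialect.classes under its Dialects enum
# value (or the lowercased class name) and autofills the tokenizer/parser/generator hooks.
# Registered dialects can then be looked up by name:
#
#     >>> Dialect.get("duckdb")  # registered when sqlglot.dialects.duckdb defines DuckDB(Dialect)
#     <class 'sqlglot.dialects.duckdb.DuckDB'>
#     >>> Dialect["unknown"]     # raises KeyError; Dialect.get_or_raise gives a clearer error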
class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
normalize_functions = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
time_mapping = {}
# autofilled
quote_start = None
quote_end = None
identifier_start = None
identifier_end = None
time_trie = None
inverse_time_mapping = None
inverse_time_trie = None
tokenizer_class = None
parser_class = None
generator_class = None
tokenizer = None
@classmethod
def get_or_raise(cls, dialect):
if not dialect:
return cls
result = cls.get(dialect)
if not result:
raise ValueError(f"Unknown dialect '{dialect}'")
return result
@classmethod
def format_time(cls, expression):
if isinstance(expression, str):
return exp.Literal.string(
format_time(
expression[1:-1], # the time formats are quoted
cls.time_mapping,
cls.time_trie,
)
)
if expression and expression.is_string:
return exp.Literal.string(
format_time(
expression.this,
cls.time_mapping,
cls.time_trie,
)
)
return expression
def parse(self, sql, **opts):
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
def parse_into(self, expression_type, sql, **opts):
return self.parser(**opts).parse_into(
expression_type, self.tokenizer.tokenize(sql), sql
)
def generate(self, expression, **opts):
return self.generator(**opts).generate(expression)
def transpile(self, code, **opts):
return self.generate(self.parse(code), **opts)
def parser(self, **opts):
return self.parser_class(
**{
"index_offset": self.index_offset,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"null_ordering": self.null_ordering,
**opts,
},
)
def generator(self, **opts):
return self.generator_class(
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"escape": self.tokenizer_class.ESCAPE,
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
**opts,
}
)
def rename_func(name):
return (
lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
)
def approx_count_distinct_sql(self, expression):
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})"
def if_sql(self, expression):
expressions = csv(
self.sql(expression, "this"),
self.sql(expression, "true"),
self.sql(expression, "false"),
)
return f"IF({expressions})"
def arrow_json_extract_sql(self, expression):
return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
def arrow_json_extract_scalar_sql(self, expression):
return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
def inline_array_sql(self, expression):
return f"[{self.expressions(expression)}]"
def no_ilike_sql(self, expression):
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this),
expression=expression.args["expression"],
)
)
def no_paren_current_date_sql(self, expression):
zone = self.sql(expression, "this")
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql(self, expression):
if expression.args.get("recursive"):
self.unsupported("Recursive CTEs are unsupported")
expression.args["recursive"] = False
return self.with_sql(expression)
def no_safe_divide_sql(self, expression):
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql(self, expression):
self.unsupported("TABLESAMPLE unsupported")
return self.sql(expression.this)
def no_trycast_sql(self, expression):
return self.cast_sql(expression)
def str_position_sql(self, expression):
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
if position:
return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
return f"STRPOS({this}, {substr})"
def struct_extract_sql(self, expression):
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
return f"{this}.{struct_key}"
def format_time_lambda(exp_class, dialect, default=None):
"""Helper used for time expressions.
Args:
exp_class (Class): the expression class to instantiate
dialect (string): sql dialect
default (Optional[bool | str]): the default format; True means use the dialect's time format
"""
def _format_time(args):
return exp_class(
this=list_get(args, 0),
format=Dialect[dialect].format_time(
list_get(args, 1)
or (Dialect[dialect].time_format if default is True else default)
),
)
return _format_time
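# Illustrative usage (editorial sketch): dialects point their FUNCTIONS tables at
# format_time_lambda so parsed format strings are translated through the dialect's
# time_mapping. DuckDB, for example, registers
#
#     "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb")
#
# so STRPTIME(x, '<fmt>') parses into exp.StrToTime(this=x, format=<translated literal>);
# when the format argument is missing, `default` is used instead (the dialect's time_format
# if default is True).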

156
sqlglot/dialects/duckdb.py Normal file
View file

@ -0,0 +1,156 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
no_safe_divide_sql,
no_tablesample_sql,
rename_func,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _unix_to_time(self, expression):
return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_add(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + INTERVAL {e} {unit}"
def _array_sort_sql(self, expression):
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.FALSE:
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
def _struct_pack_sql(self, expression):
args = [
self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
for e in expression.expressions
]
return f"STRUCT_PACK({', '.join(args)})"
class DuckDB(Dialect):
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
":=": TokenType.EQ,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
this=list_get(args, 0),
expression=exp.Literal.number(1000),
)
),
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRUCT_PACK": exp.Struct.from_arg_list,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
exp.Split: rename_func("STR_SPLIT"),
exp.SortArray: _sort_array_sql,
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_pack_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}

304
sqlglot/dialects/hive.py Normal file
View file

@ -0,0 +1,304 @@
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
format_time_lambda,
if_sql,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
struct_extract_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import csv, list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer
def _parse_map(args):
keys = []
values = []
for i in range(0, len(args), 2):
keys.append(args[i])
values.append(args[i + 1])
return HiveMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
)
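# Worked example (from the loop above): Hive's variadic MAP(k1, v1, k2, v2, ...) is split
# into parallel key and value arrays, e.g.
#
#     MAP('a', 1, 'b', 2)  ->  HiveMap(keys=ARRAY('a', 'b'), values=ARRAY(1, 2))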
def _map_sql(self, expression):
keys = expression.args["keys"]
values = expression.args["values"]
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
self.unsupported("Cannot convert array columns into map use SparkSQL instead.")
return f"MAP({self.sql(keys)}, {self.sql(values)})"
args = []
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return f"MAP({csv(*args)})"
def _array_sort(self, expression):
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}' = {value}"
def _str_to_unix(self, expression):
return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})"
def _str_to_date(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS DATE)"
def _str_to_time(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS TIMESTAMP)"
def _time_format(self, expression):
time_format = self.format_time(expression)
if time_format == Hive.time_format:
return None
return time_format
def _time_to_str(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
def _to_date_sql(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.time_format, Hive.date_format):
return f"TO_DATE({this}, {time_format})"
return f"TO_DATE({this})"
def _unnest_to_explode_sql(self, expression):
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
return "".join(
self.sql(
exp.Lateral(
this=udtf(this=expression),
alias=exp.TableAlias(this=alias.this, columns=[column]),
)
)
for expression, column in zip(
unnest.expressions, alias.columns if alias else []
)
)
return self.join_sql(expression)
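# Rough sketch of the rewrite above (exact output depends on lateral_sql): a join over
# UNNEST becomes a LATERAL VIEW, using POSEXPLODE when ordinality is requested:
#
#     ... CROSS JOIN UNNEST(arr) AS t(col)   ~>   ... LATERAL VIEW EXPLODE(arr) t AS col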
def _index_sql(self, expression):
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON TABLE {table} {columns}"
class HiveMap(exp.Map):
is_var_len_args = True
class Hive(Dialect):
alias_post_tablesample = True
time_mapping = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
"yy": "%y",
"MMMM": "%B",
"MMM": "%b",
"MM": "%m",
"M": "%-m",
"dd": "%d",
"d": "%-d",
"HH": "%H",
"H": "%-H",
"hh": "%I",
"h": "%-I",
"mm": "%M",
"m": "%-M",
"ss": "%S",
"s": "%-S",
"S": "%f",
}
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
class Tokenizer(Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
ESCAPE = "\\"
ENCODE = "utf-8"
NUMERIC_LITERALS = {
"L": "BIGINT",
"S": "SMALLINT",
"Y": "TINYINT",
"D": "DOUBLE",
"F": "FLOAT",
"BD": "DECIMAL",
}
class Parser(Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
expression=list_get(args, 1),
unit=exp.Literal.string("DAY"),
),
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=list_get(args, 0)),
expression=exp.TsOrDsToDate(this=list_get(args, 1)),
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
expression=exp.Mul(
this=list_get(args, 1),
expression=exp.Literal.number(-1),
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
this=list_get(args, 1),
substr=list_get(args, 0),
position=list_get(args, 2),
),
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
),
"MAP": _parse_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
"COLLECT_SET": exp.SetAgg.from_arg_list,
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
class Generator(Generator):
ROOT_PROPERTIES = [
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
]
WITH_PROPERTIES = [exp.AnonymousProperty]
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
exp.With: no_recursive_cte_sql,
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: _map_sql,
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Quantile: rename_func("PERCENTILE"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}",
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
}
def with_properties(self, properties):
return self.properties(
properties,
prefix="TBLPROPERTIES",
)
def datatype_sql(self, expression):
if (
expression.this
in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions
):
expression = exp.DataType.build("text")
return super().datatype_sql(expression)

163
sqlglot/dialects/mysql.py Normal file
View file

@ -0,0 +1,163 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
Dialect,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _date_trunc_sql(self, expression):
unit = expression.text("unit").lower()
this = self.sql(expression.this)
if unit == "day":
return f"DATE({this})"
if unit == "week":
concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
concat = f"CONCAT(YEAR({this}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported("Unexpected interval unit: {unit}")
return f"DATE({this})"
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
date_format = MySQL.format_time(list_get(args, 1))
return exp.StrToDate(this=list_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression):
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
return expression_class(
this=list_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
return func
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
return func
class MySQL(Dialect):
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
time_mapping = {
"%M": "%B",
"%c": "%-m",
"%e": "%-d",
"%h": "%I",
"%i": "%M",
"%s": "%S",
"%S": "%S",
"%u": "%W",
}
class Tokenizer(Tokenizer):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
"_BINARY": TokenType.INTRODUCER,
"_CP1250": TokenType.INTRODUCER,
"_CP1251": TokenType.INTRODUCER,
"_CP1256": TokenType.INTRODUCER,
"_CP1257": TokenType.INTRODUCER,
"_CP850": TokenType.INTRODUCER,
"_CP852": TokenType.INTRODUCER,
"_CP866": TokenType.INTRODUCER,
"_CP932": TokenType.INTRODUCER,
"_DEC8": TokenType.INTRODUCER,
"_EUCJPMS": TokenType.INTRODUCER,
"_EUCKR": TokenType.INTRODUCER,
"_GB18030": TokenType.INTRODUCER,
"_GB2312": TokenType.INTRODUCER,
"_GBK": TokenType.INTRODUCER,
"_GEOSTD8": TokenType.INTRODUCER,
"_GREEK": TokenType.INTRODUCER,
"_HEBREW": TokenType.INTRODUCER,
"_HP8": TokenType.INTRODUCER,
"_KEYBCS2": TokenType.INTRODUCER,
"_KOI8R": TokenType.INTRODUCER,
"_KOI8U": TokenType.INTRODUCER,
"_LATIN1": TokenType.INTRODUCER,
"_LATIN2": TokenType.INTRODUCER,
"_LATIN5": TokenType.INTRODUCER,
"_LATIN7": TokenType.INTRODUCER,
"_MACCE": TokenType.INTRODUCER,
"_MACROMAN": TokenType.INTRODUCER,
"_SJIS": TokenType.INTRODUCER,
"_SWE7": TokenType.INTRODUCER,
"_TIS620": TokenType.INTRODUCER,
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
}
class Parser(Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
}
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateSub: _date_add_sql("SUB"),
exp.DateTrunc: _date_trunc_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
}

View file

@ -0,0 +1,63 @@
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot.generator import Generator
from sqlglot.helper import csv
from sqlglot.tokens import Tokenizer, TokenType
def _limit_sql(self, expression):
return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
class Oracle(Dialect):
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
exp.DataType.Type.BIGINT: "NUMBER",
exp.DataType.Type.DECIMAL: "NUMBER",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.VARCHAR: "VARCHAR2",
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.BINARY: "BLOB",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
}
def query_modifiers(self, expression, *sqls):
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
*[self.sql(sql) for sql in expression.args.get("joins", [])],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.sql(expression, "window"),
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
self.sql(expression, "offset"), # offset before limit in oracle
self.sql(expression, "limit"),
sep="",
)
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
}

View file

@ -0,0 +1,109 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _date_add_sql(kind):
def func(self, expression):
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
unit = self.sql(expression, "unit")
expression = simplify(expression.args["expression"])
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")
expression = expression.copy()
expression.args["is_string"] = True
expression = self.sql(expression)
return f"{this} {kind} INTERVAL {expression} {unit}"
return func
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
"AM": "%p", # AM or PM
"D": "%w", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
"FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres
"FMDDD": "%-j", # day of year
"FMHH12": "%-I", # 9
"FMHH24": "%-H", # 9
"FMMI": "%-M", # Minute
"FMMM": "%-m", # 1
"FMSS": "%-S", # Second
"HH12": "%I", # 09
"HH24": "%H", # 09
"MI": "%M", # zero padded minute
"MM": "%m", # 01
"OF": "%z", # utc offset
"SS": "%S", # zero padded second
"TMDay": "%A", # TM is locale dependent
"TMDy": "%a",
"TMMon": "%b", # Sep
"TMMonth": "%B", # September
"TZ": "%Z", # uppercase timezone name
"US": "%f", # zero padded microsecond
"WW": "%U", # 1-based week of year
"YY": "%y", # 15
"YYYY": "%Y", # 2015
}
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"SERIAL": TokenType.AUTO_INCREMENT,
"UUID": TokenType.UUID,
}
class Parser(Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
}
TOKEN_MAPPING = {
TokenType.AUTO_INCREMENT: "SERIAL",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}

216
sqlglot/dialects/presto.py Normal file
View file

@ -0,0 +1,216 @@
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
if_sql,
no_ilike_sql,
no_safe_divide_sql,
rename_func,
str_position_sql,
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.generator import Generator
from sqlglot.helper import csv, list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _approx_distinct_sql(self, expression):
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _concat_ws_sql(self, expression):
sep, *args = expression.expressions
sep = self.sql(sep)
if len(args) > 1:
return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})"
return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
def _datatype_sql(self, expression):
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
sql = f"{sql} WITH TIME ZONE"
return sql
def _date_parse_sql(self, expression):
return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
def _explode_to_unnest_sql(self, expression):
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
return self.sql(
exp.Join(
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
ordinality=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
)
return self.lateral_sql(expression)
def _initcap_sql(self, expression):
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _no_sort_array(self, expression):
if expression.args.get("asc") == exp.FALSE:
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
args = csv(self.sql(expression, "this"), comparator)
return f"ARRAY_SORT({args})"
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
if isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(schema.expressions)
return self.schema_sql(expression)
def _quantile_sql(self, expression):
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(self, expression):
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return (
f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
)
def _ts_or_ds_add_sql(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
unit = self.sql(expression, "unit") or "'day'"
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
time_mapping = MySQL.time_mapping
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"ROW": TokenType.STRUCT,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
}
class Generator(Generator):
STRUCT_DELIMITER = ("(", ")")
WITH_PROPERTIES = [
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
]
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.ConcatWs: _concat_ws_sql,
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.FileFormatProperty: lambda self, e: self.property_sql(e),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
exp.TimeStrToDate: _date_parse_sql,
exp.TimeStrToTime: _date_parse_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}

View file

@ -0,0 +1,145 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _check_int(s):
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args):
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
return format_time_lambda(exp.StrToTime, "snowflake")(args)
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
raise ValueError(
f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
)
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
elif second_arg.name == "3":
timescale = exp.UnixToTime.MILLIS
elif second_arg.name == "9":
timescale = exp.UnixToTime.MICROS
return exp.UnixToTime(this=first_arg, scale=timescale)
first_arg = list_get(args, 0)
if not isinstance(first_arg, Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
if first_arg.is_string:
if _check_int(first_arg.this):
# case: <integer>
return exp.UnixToTime.from_arg_list(args)
# case: <date_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
# case: <numeric_expr>
return exp.UnixToTime.from_arg_list(args)
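# Summary of the cases handled above (mirroring the Snowflake docs linked at the top):
#
#     TO_TIMESTAMP('04/05/2013', 'mm/dd/yyyy')  ->  StrToTime with a translated format
#     TO_TIMESTAMP(col, 3)                      ->  UnixToTime at millisecond scale
#     TO_TIMESTAMP(1618088028)                  ->  UnixToTime (numeric seconds)
#     TO_TIMESTAMP('2020-01-01')                ->  StrToTime using the dialect's default format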
def _unix_to_time(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
return f"TO_TIMESTAMP({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TO_TIMESTAMP({timestamp}, 3)"
if scale == exp.UnixToTime.MICROS:
return f"TO_TIMESTAMP({timestamp}, 9)"
raise ValueError("Improper scale for timestamp")
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
time_mapping = {
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
"yy": "%y",
"MMMM": "%B",
"mmmm": "%B",
"MON": "%b",
"mon": "%b",
"MM": "%m",
"mm": "%m",
"DD": "%d",
"dd": "%d",
"d": "%-d",
"DY": "%w",
"dy": "%w",
"HH24": "%H",
"hh24": "%H",
"HH12": "%I",
"hh12": "%I",
"MI": "%M",
"mi": "%M",
"SS": "%S",
"ss": "%S",
"FF": "%f",
"ff": "%f",
"FF6": "%f",
"ff6": "%f",
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
expressions=[path],
),
}
class Tokenizer(Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
KEYWORDS = {
**Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
}
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
def intersect_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)

106
sqlglot/dialects/spark.py Normal file
View file

@ -0,0 +1,106 @@
from sqlglot import exp
from sqlglot.dialects.dialect import no_ilike_sql, rename_func
from sqlglot.dialects.hive import Hive, HiveMap
from sqlglot.helper import list_get
def _create_sql(self, e):
kind = e.args.get("kind")
temporary = e.args.get("temporary")
if kind.upper() == "TABLE" and temporary is True:
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return self.create_sql(e)
def _map_sql(self, expression):
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _str_to_date(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
def _unix_to_time(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
return f"FROM_UNIXTIME({timestamp})"
if scale == exp.UnixToTime.SECONDS:
return f"TIMESTAMP_SECONDS({timestamp})"
if scale == exp.UnixToTime.MILLIS:
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
raise ValueError("Improper scale for timestamp")
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=list_get(args, 0),
start=exp.Literal.number(1),
length=list_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=list_get(args, 0),
expression=list_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=list_get(args, 0),
expression=list_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=list_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
expression=exp.Add(
this=list_get(args, 1), expression=exp.Literal.number(1)
),
),
length=list_get(args, 1),
),
}
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
**{
k: v
for k, v in Hive.Generator.TRANSFORMS.items()
if k not in {exp.ArraySort}
},
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.Create: _create_sql,
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
HiveMap: _map_sql,
}
def bitstring_sql(self, expression):
return f"X'{self.sql(expression, 'this')}'"

View file

@ -0,0 +1,63 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
no_ilike_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.BIGINT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "REAL",
exp.DataType.Type.DECIMAL: "REAL",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.NCHAR: "TEXT",
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.BINARY: "BLOB",
}
TOKEN_MAPPING = {
TokenType.AUTO_INCREMENT: "AUTOINCREMENT",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}

View file

@ -0,0 +1,12 @@
from sqlglot import exp
from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator):
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}

View file

@ -0,0 +1,37 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
def _if_sql(self, expression):
return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
def _coalesce_sql(self, expression):
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
def _count_sql(self, expression):
this = expression.this
if isinstance(this, exp.Distinct):
return f"COUNTD({self.sql(this, 'this')})"
return f"COUNT({self.sql(expression, 'this')})"
class Tableau(Dialect):
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))),
}

10
sqlglot/dialects/trino.py Normal file
View file

@ -0,0 +1,10 @@
from sqlglot import exp
from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}

314
sqlglot/diff.py Normal file
View file

@ -0,0 +1,314 @@
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list
@dataclass(frozen=True)
class Insert:
"""Indicates that a new node has been inserted"""
expression: exp.Expression
@dataclass(frozen=True)
class Remove:
"""Indicates that an existing node has been removed"""
expression: exp.Expression
@dataclass(frozen=True)
class Move:
"""Indicates that an existing node's position within the tree has changed"""
expression: exp.Expression
@dataclass(frozen=True)
class Update:
"""Indicates that an existing node has been updated"""
source: exp.Expression
target: exp.Expression
@dataclass(frozen=True)
class Keep:
"""Indicates that an existing node hasn't been changed"""
source: exp.Expression
target: exp.Expression
def diff(source, target):
"""
Returns the list of changes between the source and the target expressions.
Examples:
>>> diff(parse_one("a + b"), parse_one("a + c"))
[
Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
Keep(
source=(ADD this: ...),
target=(ADD this: ...)
),
Keep(
source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
),
]
Args:
source (sqlglot.Expression): the source expression.
target (sqlglot.Expression): the target expression against which the diff should be calculated.
Returns:
the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
This list represents a sequence of steps needed to transform the source expression tree into the target one.
"""
return ChangeDistiller().diff(source.copy(), target.copy())
LEAF_EXPRESSION_TYPES = (
exp.Boolean,
exp.DataType,
exp.Identifier,
exp.Literal,
)
class ChangeDistiller:
"""
The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""
def __init__(self, f=0.6, t=0.6):
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
def diff(self, source, target):
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
self._bigram_histo_cache = {}
matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set)
def _generate_edit_script(self, matching_set):
edit_script = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes:
edit_script.append(Insert(self._target_index[inserted_node_id]))
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if (
not isinstance(source_node, LEAF_EXPRESSION_TYPES)
or source_node == target_node
):
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
return edit_script
def _generate_move_edits(self, source, target, matching_set):
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
args_lcs = set(
_lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
)
move_edits = []
for a in source_args:
if a not in args_lcs and a not in self._unmatched_source_nodes:
move_edits.append(Move(self._source_index[a]))
return move_edits
def _compute_matching_set(self):
leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
id(n[0]): None
for n in self._source.bfs()
if id(n[0]) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
id(n[0]): None
for n in self._target.bfs()
if id(n[0]) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
for target_node_id in ordered_unmatched_target_nodes:
source_node = self._source_index[source_node_id]
target_node = self._target_index[target_node_id]
if _is_same_type(source_node, target_node):
source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
target_leaf_ids = {id(l) for l in _get_leaves(target_node)}
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
1 if s in source_leaf_ids and t in target_leaf_ids else 0
for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
adjusted_t = (
self.t
if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
else 0.4
)
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
and self._dice_coefficient(source_node, target_node) >= self.f
):
matching_set.add((source_node_id, target_node_id))
self._unmatched_source_nodes.remove(source_node_id)
self._unmatched_target_nodes.remove(target_node_id)
ordered_unmatched_target_nodes.pop(target_node_id, None)
break
return matching_set
def _compute_leaf_matching_set(self):
candidate_matchings = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
for target_leaf in target_leaves:
if _is_same_type(source_leaf, target_leaf):
similarity_score = self._dice_coefficient(source_leaf, target_leaf)
if similarity_score >= self.f:
heappush(
candidate_matchings,
(
-similarity_score,
len(candidate_matchings),
source_leaf,
target_leaf,
),
)
# Pick best matchings based on the highest score
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
if (
id(source_leaf) in self._unmatched_source_nodes
and id(target_leaf) in self._unmatched_target_nodes
):
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
return matching_set
def _dice_coefficient(self, source, target):
source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target)
total_grams = sum(source_histo.values()) + sum(target_histo.values())
if not total_grams:
return 1.0 if source == target else 0.0
overlap_len = 0
overlapping_grams = set(source_histo) & set(target_histo)
for g in overlapping_grams:
overlap_len += min(source_histo[g], target_histo[g])
return 2 * overlap_len / total_grams
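# Worked example for the dice coefficient above: if the two nodes generate the SQL strings
# "ab" and "abc", their bigram histograms are {"ab": 1} and {"ab": 1, "bc": 1}, so
# total_grams = 3, overlap_len = 1, and the similarity is 2 * 1 / 3 ~= 0.67.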
def _bigram_histo(self, expression):
if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1)
bigram_histo = defaultdict(int)
for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1
self._bigram_histo_cache[id(expression)] = bigram_histo
return bigram_histo
def _get_leaves(expression):
has_child_exprs = False
for a in expression.args.values():
nodes = ensure_list(a)
for node in nodes:
if isinstance(node, exp.Expression):
has_child_exprs = True
yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source, target):
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
if isinstance(source, exp.Anonymous):
return source.this == target.this
return True
return False
def _expression_only_args(expression):
args = []
if expression:
for a in expression.args.values():
args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]
def _lcs(seq_a, seq_b, equal):
"""Calculates the longest common subsequence"""
len_a = len(seq_a)
len_b = len(seq_b)
lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)]
for i in range(len_a + 1):
for j in range(len_b + 1):
if i == 0 or j == 0:
lcs_result[i][j] = []
elif equal(seq_a[i - 1], seq_b[j - 1]):
lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
else:
lcs_result[i][j] = (
lcs_result[i - 1][j]
if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
else lcs_result[i][j - 1]
)
return lcs_result[len_a][len_b]
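# Worked example for _lcs above:
#
#     >>> _lcs([1, 2, 3, 4], [2, 4, 5], lambda a, b: a == b)
#     [2, 4]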

38
sqlglot/errors.py Normal file
View file

@ -0,0 +1,38 @@
from enum import auto
from sqlglot.helper import AutoName
class ErrorLevel(AutoName):
IGNORE = auto() # Ignore any parser errors
WARN = auto() # Log any parser errors with ERROR level
RAISE = auto() # Collect all parser errors and raise a single exception
IMMEDIATE = auto() # Immediately raise an exception on the first parser error
class SqlglotError(Exception):
pass
class UnsupportedError(SqlglotError):
pass
class ParseError(SqlglotError):
pass
class TokenError(SqlglotError):
pass
class OptimizeError(SqlglotError):
pass
def concat_errors(errors, maximum):
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:
msg.append(f"... and {remaining} more")
return "\n\n".join(msg)

View file

@ -0,0 +1,39 @@
import logging
import time
from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
logger = logging.getLogger("sqlglot")
def execute(sql, schema, read=None):
"""
Run a sql query against data.
Args:
sql (str): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
expression = parse_one(sql, read=read)
now = time.time()
expression = optimize(expression, schema)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
result = PythonExecutor().execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
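# Minimal usage sketch (editorial; assumes this module is importable as sqlglot.executor,
# and the table name, columns and data source are made up for illustration and must be
# resolvable by the Python executor's scan step):
#
#     from sqlglot.executor import execute
#     result = execute(
#         "SELECT name, SUM(price) AS total FROM sushi GROUP BY name",
#         schema={"sushi": {"name": "VARCHAR", "price": "DOUBLE"}},
#         read="presto",
#     )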

View file

@ -0,0 +1,68 @@
from sqlglot.executor.env import ENV
class Context:
"""
Execution context for sql expressions.
Context is used to hold relevant data tables which can then be queried on with eval.
References to columns can either be scalar or vectors. When set_row is used, column references
evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient
evaluation of aggregation functions.
"""
def __init__(self, tables, env=None):
"""
Args:
tables (dict): table_name -> Table, representing the scope of the current execution context
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
self.range_readers = {
name: table.range_reader for name, table in self.tables.items()
}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
def eval(self, code):
return eval(code, ENV, self.env)
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
def __iter__(self):
return self.table_iter(list(self.tables)[0])
def table_iter(self, table):
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
def sort(self, table, key):
table = self.tables[table]
def sort_key(row):
table.reader.row = row
return self.eval_tuple(key)
table.rows.sort(key=sort_key)
def set_row(self, table, row):
self.row_readers[table].row = row
self.env["scope"] = self.row_readers
def set_index(self, table, index):
self.row_readers[table].row = self.tables[table].rows[index]
self.env["scope"] = self.row_readers
def set_range(self, table, start, end):
self.range_readers[table].range = range(start, end)
self.env["scope"] = self.range_readers
def __getitem__(self, table):
return self.env["scope"][table]
def __contains__(self, table):
return table in self.tables
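A minimal sketch (not part of the diff) of evaluating a row-level expression through Context; the code string mirrors the scope[table][column] form that PythonExecutor generates.
from sqlglot.executor.table import Table

t = Table(["a", "b"])
t.append((1, 2))
ctx = Context({"t": t})
ctx.set_row("t", (1, 2))
print(ctx.eval('scope["t"]["a"] + scope["t"]["b"]'))  # expected: 3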

32
sqlglot/executor/env.py Normal file
View file

@ -0,0 +1,32 @@
import datetime
import re
import statistics
class reverse_key:
def __init__(self, obj):
self.obj = obj
def __eq__(self, other):
return other.obj == self.obj
def __lt__(self, other):
return other.obj < self.obj
ENV = {
"__builtins__": {},
"datetime": datetime,
"locals": locals,
"re": re,
"float": float,
"int": int,
"str": str,
"desc": reverse_key,
"SUM": sum,
"AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
"COUNT": lambda acc: sum(1 for e in acc if e is not None),
"MAX": max,
"MIN": min,
"POW": pow,
}
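Illustrative usage (not part of the diff): reverse_key lets the executor express ORDER BY ... DESC with a plain ascending sort.
print(sorted([1, 3, 2], key=reverse_key))  # expected: [3, 2, 1]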

360
sqlglot/executor/python.py Normal file
View file

@ -0,0 +1,360 @@
import ast
import collections
import itertools
from sqlglot import exp, planner
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
from sqlglot.tokens import Tokenizer
class PythonExecutor:
def __init__(self, env=None):
self.generator = Python().generator(identify=True)
self.env = {**ENV, **(env or {})}
def execute(self, plan):
running = set()
finished = set()
queue = set(plan.leaves)
contexts = {}
while queue:
node = queue.pop()
context = self.context(
{
name: table
for dep in node.dependencies
for name, table in contexts[dep].tables.items()
}
)
running.add(node)
if isinstance(node, planner.Scan):
contexts[node] = self.scan(node, context)
elif isinstance(node, planner.Aggregate):
contexts[node] = self.aggregate(node, context)
elif isinstance(node, planner.Join):
contexts[node] = self.join(node, context)
elif isinstance(node, planner.Sort):
contexts[node] = self.sort(node, context)
else:
raise NotImplementedError
running.remove(node)
finished.add(node)
for dep in node.dependents:
if dep not in running and all(d in contexts for d in dep.dependencies):
queue.add(dep)
for dep in node.dependencies:
if all(d in finished for d in dep.dependents):
contexts.pop(dep)
root = plan.root
return contexts[root].tables[root.name]
def generate(self, expression):
"""Convert a SQL expression into literal Python code and compile it into bytecode."""
if not expression:
return None
sql = self.generator.generate(expression)
return compile(sql, sql, "eval", optimize=2)
def generate_tuple(self, expressions):
"""Convert a sequence of SQL expressions into a tuple of compiled Python code objects."""
if not expressions:
return tuple()
return tuple(self.generate(expression) for expression in expressions)
def context(self, tables):
return Context(tables, env=self.env)
def table(self, expressions):
return Table(expression.alias_or_name for expression in expressions)
def scan(self, step, context):
if hasattr(step, "source"):
source = step.source
if isinstance(source, exp.Expression):
source = source.this.name or source.alias
else:
source = step.name
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
if source in context:
if not projections and not condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
else:
table_iter = self.scan_csv(step)
if projections:
sink = self.table(step.projections)
elif source in context:
sink = Table(context[source].columns)
else:
sink = None
for reader, ctx in table_iter:
if sink is None:
sink = Table(ctx[source].columns)
if condition and not ctx.eval(condition):
continue
if projections:
sink.append(ctx.eval_tuple(projections))
else:
sink.append(reader.row)
if len(sink) >= step.limit:
break
return self.context({step.name: sink})
def scan_csv(self, step):
source = step.source
alias = source.alias
with csv_reader(source.this) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
types = []
for row in reader:
if not types:
for v in row:
try:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
yield context[alias], context
def join(self, step, context):
source = step.name
join_context = self.context({source: context.tables[source]})
def merge_context(ctx, table):
# create a new context where all existing tables are mapped to a new one
return self.context({name: table for name in ctx.tables})
for name, join in step.joins.items():
join_context = self.context(
{**join_context.tables, name: context.tables[name]}
)
if join.get("source_key"):
table = self.hash_join(join, source, name, join_context)
else:
table = self.nested_loop_join(join, source, name, join_context)
join_context = merge_context(join_context, table)
# apply projections or conditions
context = self.scan(step, join_context)
# use the scan context since it returns a single table
# otherwise there are no projections so all other tables are still in scope
if step.projections:
return context
return merge_context(join_context, context.tables[source])
def nested_loop_join(self, _join, a, b, context):
table = Table(context.tables[a].columns + context.tables[b].columns)
for reader_a, _ in context.table_iter(a):
for reader_b, _ in context.table_iter(b):
table.append(reader_a.row + reader_b.row)
return table
def hash_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])
results = collections.defaultdict(lambda: ([], []))
for reader, ctx in context.table_iter(a):
results[ctx.eval_tuple(a_key)][0].append(reader.row)
for reader, ctx in context.table_iter(b):
results[ctx.eval_tuple(b_key)][1].append(reader.row)
table = Table(context.tables[a].columns + context.tables[b].columns)
for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def sort_merge_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])
context.sort(a, a_key)
context.sort(b, b_key)
a_i = 0
b_i = 0
a_n = len(context.tables[a])
b_n = len(context.tables[b])
table = Table(context.tables[a].columns + context.tables[b].columns)
def get_key(source, key, i):
context.set_index(source, i)
return context.eval_tuple(key)
while a_i < a_n and b_i < b_n:
key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
a_group = []
while a_i < a_n and key == get_key(a, a_key, a_i):
a_group.append(context[a].row)
a_i += 1
b_group = []
while b_i < b_n and key == get_key(b, b_key, b_i):
b_group.append(context[b].row)
b_i += 1
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
source = step.source
group_by = self.generate_tuple(step.group)
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
context.sort(source, group_by)
if step.operands:
source_table = context.tables[source]
operand_table = Table(
source_table.columns + self.table(step.operands).columns
)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
context = self.context({source: operand_table})
group = None
start = 0
end = 1
length = len(context.tables[source])
table = self.table(step.group + step.aggregations)
for i in range(length):
context.set_index(source, i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if i == length - 1:
context.set_range(source, start, end - 1)
elif key != group:
context.set_range(source, start, end - 2)
else:
continue
table.append(group + context.eval_tuple(aggregations))
group = key
start = end - 2
return self.scan(step, self.context({source: table}))
def sort(self, step, context):
table = list(context.tables)[0]
key = self.generate_tuple(step.key)
context.sort(table, key)
return self.scan(step, context)
def _cast_py(self, expression):
to = expression.args["to"].this
this = self.sql(expression, "this")
if to == exp.DataType.Type.DATE:
return f"datetime.date.fromisoformat({this})"
if to == exp.DataType.Type.TEXT:
return f"str({this})"
raise NotImplementedError
def _column_py(self, expression):
table = self.sql(expression, "table")
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
def _interval_py(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper()
if unit == "DAY":
return f"datetime.timedelta(days=float({this}))"
raise NotImplementedError
def _like_py(self, expression):
this = self.sql(expression, "this")
expression = self.sql(expression, "expression")
return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})"""
def _ordered_py(self, expression):
this = self.sql(expression, "this")
desc = expression.args.get("desc")
return f"desc({this})" if desc else this
class Python(Dialect):
class Tokenizer(Tokenizer):
ESCAPE = "\\"
class Generator(Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
exp.Cast: _cast_py,
exp.Column: _column_py,
exp.EQ: lambda self, e: self.binary(e, "=="),
exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Like: _like_py,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
def case_sql(self, expression):
this = self.sql(expression, "this")
chain = self.sql(expression, "default") or "None"
for e in reversed(expression.args["ifs"]):
true = self.sql(e, "true")
condition = self.sql(e, "this")
condition = f"{this} == ({condition})" if this else condition
chain = f"{true} if {condition} else ({chain})"
return chain
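A sketch (not part of the diff) of how a SQL predicate becomes compiled Python: the Python dialect above renders columns as scope[table][column] and = as ==, so the generated source is roughly scope["x"]["a"] == 1, which Context.eval can then evaluate against row readers.
from sqlglot import parse_one

where = parse_one("SELECT * FROM x WHERE x.a = 1").args["where"].this
code = PythonExecutor().generate(where)  # compiled code object, roughly: scope["x"]["a"] == 1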

81
sqlglot/executor/table.py Normal file
View file

@ -0,0 +1,81 @@
class Table:
def __init__(self, *columns, rows=None):
self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.reader = RowReader(self.columns)
self.range_reader = RangeReader(self)
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
def pop(self):
self.rows.pop()
@property
def width(self):
return len(self.columns)
def __len__(self):
return len(self.rows)
def __iter__(self):
return TableIter(self)
def __getitem__(self, index):
self.reader.row = self.rows[index]
return self.reader
def __repr__(self):
widths = {column: len(column) for column in self.columns}
lines = [" ".join(column for column in self.columns)]
for i, row in enumerate(self):
if i > 10:
break
lines.append(
" ".join(
str(row[column]).rjust(widths[column])[0 : widths[column]]
for column in self.columns
)
)
return "\n".join(lines)
class TableIter:
def __init__(self, table):
self.table = table
self.index = -1
def __iter__(self):
return self
def __next__(self):
self.index += 1
if self.index < len(self.table):
return self.table[self.index]
raise StopIteration
class RangeReader:
def __init__(self, table):
self.table = table
self.range = range(0)
def __len__(self):
return len(self.range)
def __getitem__(self, column):
return (self.table[i][column] for i in self.range)
class RowReader:
def __init__(self, columns):
self.columns = {column: i for i, column in enumerate(columns)}
self.row = None
def __getitem__(self, column):
return self.row[self.columns[column]]
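Illustrative usage (not part of the diff) of the columnar Table and its readers:
t = Table("a", "b")
t.append((1, 2))
t.append((3, 4))
for reader in t:
    print(reader["a"], reader["b"])  # 1 2, then 3 4
print(t.width, len(t))  # expected: 2 2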

2945
sqlglot/expressions.py Normal file

File diff suppressed because it is too large Load diff

1124
sqlglot/generator.py Normal file

File diff suppressed because it is too large Load diff

123
sqlglot/helper.py Normal file
View file

@ -0,0 +1,123 @@
import logging
import re
from contextlib import contextmanager
from enum import Enum
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
logger = logging.getLogger("sqlglot")
class AutoName(Enum):
def _generate_next_value_(name, _start, _count, _last_values):
return name
def list_get(arr, index):
try:
return arr[index]
except IndexError:
return None
def ensure_list(value):
if value is None:
return []
return value if isinstance(value, (list, tuple, set)) else [value]
def csv(*args, sep=", "):
return sep.join(arg for arg in args if arg)
def apply_index_offset(expressions, offset):
if not offset or len(expressions) != 1:
return expressions
expression = expressions[0]
if expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
return [expression]
return expressions
def camel_to_snake_case(name):
return CAMEL_CASE_PATTERN.sub("_", name).upper()
def while_changing(expression, func):
while True:
start = hash(expression)
expression = func(expression)
if start == hash(expression):
break
return expression
def tsort(dag):
result = []
def visit(node, visited):
if node in result:
return
if node in visited:
raise ValueError("Cycle error")
visited.add(node)
for dep in dag.get(node, []):
visit(dep, visited)
visited.remove(node)
result.append(node)
for node in dag:
visit(node, set())
return result
def open_file(file_name):
"""
Open a file that may be gzip-compressed and return a text-mode handle with newline translation disabled.
"""
with open(file_name, "rb") as f:
gzipped = f.read(2) == b"\x1f\x8b"
if gzipped:
import gzip
return gzip.open(file_name, "rt", newline="")
return open(file_name, "rt", encoding="utf-8", newline="")
@contextmanager
def csv_reader(table):
"""
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
Args:
table (exp.Table): a table expression whose `this` is the READ_CSV anonymous function
Returns:
A python csv reader.
"""
file, *args = table.this.expressions
file = file.name
file = open_file(file)
delimiter = ","
args = iter(arg.name for arg in args)
for k, v in zip(args, args):
if k == "delimiter":
delimiter = v
try:
import csv as csv_
yield csv_.reader(file, delimiter=delimiter)
finally:
file.close()
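A few illustrative calls (not part of the diff) for the helpers above:
print(ensure_list("x"), ensure_list(None))        # expected: ['x'] []
print(camel_to_snake_case("OrderBy"))             # expected: ORDER_BY
print(tsort({"a": ["b"], "b": ["c"], "c": []}))   # expected: ['c', 'b', 'a']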

2
sqlglot/optimizer/__init__.py Normal file
View file

@ -0,0 +1,2 @@
from sqlglot.optimizer.optimizer import optimize
from sqlglot.optimizer.schema import Schema

48
sqlglot/optimizer/eliminate_subqueries.py Normal file
View file

@ -0,0 +1,48 @@
import itertools
from sqlglot import alias, exp, select, table
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
"""
Rewrite duplicate subqueries from sqlglot AST.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
>>> eliminate_subqueries(expression).sql()
'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
Args:
expression (sqlglot.Expression): expression to rewrite
Returns:
sqlglot.Expression: expression with duplicate subqueries rewritten as CTEs
"""
expression = simplify(expression)
queries = {}
for scope in traverse_scope(expression):
query = scope.expression
queries[query] = queries.get(query, []) + [query]
sequence = itertools.count()
for query, duplicates in queries.items():
if len(duplicates) == 1:
continue
alias_ = f"_e_{next(sequence)}"
for dup in duplicates:
parent = dup.parent
if isinstance(parent, exp.Subquery):
parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
elif isinstance(parent, exp.Union):
dup.replace(select("*").from_(alias_))
expression.with_(alias_, as_=query, copy=False)
return expression

16
sqlglot/optimizer/expand_multi_table_selects.py Normal file
View file

@ -0,0 +1,16 @@
from sqlglot import exp
def expand_multi_table_selects(expression):
for from_ in expression.find_all(exp.From):
parent = from_.parent
for query in from_.expressions[1:]:
parent.join(
query,
join_type="CROSS",
copy=False,
)
from_.expressions.remove(query)
return expression
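Illustrative usage (not part of the diff); the comment shows the expected rendering, not an output asserted by the source:
from sqlglot import parse_one

print(expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql())
# expected: SELECT * FROM x CROSS JOIN y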

31
sqlglot/optimizer/isolate_table_selects.py Normal file
View file

@ -0,0 +1,31 @@
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
def isolate_table_selects(expression):
for scope in traverse_scope(expression):
if len(scope.selected_sources) == 1:
continue
for (_, source) in scope.selected_sources.values():
if not isinstance(source, exp.Table):
continue
if not isinstance(source.parent, exp.Alias):
raise OptimizeError(
"Tables require an alias. Run qualify_tables optimization."
)
parent = source.parent
parent.replace(
exp.select("*")
.from_(
alias(source, source.name or parent.alias, table=True),
copy=False,
)
.subquery(parent.alias, copy=False)
)
return expression
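A hedged sketch (not part of the diff); tables must already carry aliases (e.g. after qualify_tables), and the output shown is the expected shape rather than a guaranteed string:
from sqlglot import parse_one

sql = "SELECT * FROM x AS x CROSS JOIN y AS y"
print(isolate_table_selects(parse_one(sql)).sql())
# expected, roughly: SELECT * FROM (SELECT * FROM x AS x) AS x CROSS JOIN (SELECT * FROM y AS y) AS y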

136
sqlglot/optimizer/normalize.py Normal file
View file

@ -0,0 +1,136 @@
from sqlglot import exp
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
def normalize(expression, dnf=False, max_distance=128):
"""
Rewrite sqlglot AST into conjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression).sql()
'(x OR z) AND (y OR z)'
Args:
expression (sqlglot.Expression): expression to normalize
dnf (bool): rewrite in disjunctive normal form instead
max_distance (int): the maximal estimated distance from cnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
expression = simplify(expression)
expression = while_changing(
expression, lambda e: distributive_law(e, dnf, max_distance)
)
return simplify(expression)
def normalized(expression, dnf=False):
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
return not any(
connector.find_ancestor(ancestor) for connector in expression.find_all(root)
)
def normalization_distance(expression, dnf=False):
"""
The difference in the number of predicates between the current expression and the normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Args:
expression (sqlglot.Expression): expression to compute distance
dnf (bool): compute to dnf distance instead
Returns:
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
)
def _predicate_lengths(expression, dnf):
"""
Returns a list of predicate lengths when expanded to normalized form.
(A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
"""
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
return [1]
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
x = [
a + b
for a in _predicate_lengths(left, dnf)
for b in _predicate_lengths(right, dnf)
]
return x
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
if isinstance(expression.unnest(), exp.Connector):
if normalization_distance(expression, dnf) > max_distance:
return expression
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
from_func = exp.and_ if from_exp == exp.And else exp.or_
to_func = exp.and_ if to_exp == exp.And else exp.or_
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(
tuple(b.find_all(exp.Connector))
):
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func)
return expression
def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
exp.paren(from_func(c, b.left)),
exp.paren(from_func(c, b.right)),
),
)
else:
a = to_func(from_func(a, b.left), from_func(a, b.right))
return _simplify(a)
def _simplify(node):
node = uniq_sort(flatten(node))
exp.replace_children(node, _simplify)
return node
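Illustrative checks (not part of the diff) for the normalized() predicate above:
from sqlglot import parse_one

print(normalized(parse_one("(x OR y) AND z")))            # expected: True  (already CNF)
print(normalized(parse_one("(x AND y) OR z")))            # expected: False
print(normalized(parse_one("(x AND y) OR z"), dnf=True))  # expected: True  (already DNF)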

75
sqlglot/optimizer/optimize_joins.py Normal file
View file

@ -0,0 +1,75 @@
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
"""
Remove cross joins if possible and reorder joins based on predicate dependencies.
"""
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
name = join.this.alias_or_name
tables = other_table_names(join, name)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
cross_joins.append((name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.TRUE)
join.on(predicate, copy=False)
expression = reorder_joins(expression)
expression = normalize(expression)
return expression
def reorder_joins(expression):
"""
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
head = from_.expressions[0]
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
for name, join in joins.items():
dag[name] = other_table_names(join, name)
parent.set(
"joins",
[joins[name] for name in tsort(dag) if name != head.alias_or_name],
)
return expression
def normalize(expression):
"""
Remove INNER and OUTER from joins as they are optional.
"""
for join in expression.find_all(exp.Join):
if join.kind != "CROSS":
join.set("kind", None)
return expression
def other_table_names(join, exclude):
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
if name != exclude
]
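Illustrative usage (not part of the diff) of the module-local normalize helper above (the join normalizer, not the CNF rewrite in normalize.py); the comment is the expected rendering:
from sqlglot import parse_one

print(normalize(parse_one("SELECT * FROM x INNER JOIN y ON x.a = y.a")).sql())
# expected: SELECT * FROM x JOIN y ON x.a = y.a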

43
sqlglot/optimizer/optimizer.py Normal file
View file

@ -0,0 +1,43 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
def optimize(expression, schema=None, db=None, catalog=None):
"""
Rewrite a sqlglot AST into an optimized form.
Args:
expression (sqlglot.Expression): expression to optimize
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
Returns:
sqlglot.Expression: optimized expression
"""
expression = expression.copy()
expression = qualify_tables(expression, db=db, catalog=catalog)
expression = isolate_table_selects(expression)
expression = qualify_columns(expression, schema)
expression = pushdown_projections(expression)
expression = normalize(expression)
expression = unnest_subqueries(expression)
expression = expand_multi_table_selects(expression)
expression = pushdown_predicates(expression)
expression = optimize_joins(expression)
expression = eliminate_subqueries(expression)
expression = quote_identities(expression)
return expression
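A sketch (not part of the diff) of running the whole pipeline; the quoted, fully qualified output is the expected result of the rules listed above:
import sqlglot

expression = sqlglot.parse_one("SELECT col FROM tbl")
print(optimize(expression, schema={"tbl": {"col": "INT"}}).sql())
# expected, roughly: SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"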

176
sqlglot/optimizer/pushdown_predicates.py Normal file
View file

@ -0,0 +1,176 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression):
"""
Rewrite sqlglot AST to push down predicates into FROMs and JOINs
Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
Args:
expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
"""
for scope in reversed(traverse_scope(expression)):
select = scope.expression
where = select.args.get("where")
if where:
pushdown(where.this, scope.selected_sources)
# a join's ON predicate should only be pushed down into its own source,
# not into other joins, so we limit the selected sources to that join alone
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
return expression
def pushdown(condition, sources):
if not condition:
return
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(
condition.flatten()
if isinstance(condition, exp.And if cnf_like else exp.Or)
else [condition]
)
if cnf_like:
pushdown_cnf(predicates, sources)
else:
pushdown_dnf(predicates, sources)
def pushdown_cnf(predicates, scope):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope).values():
if isinstance(node, exp.Join):
predicate.replace(exp.TRUE)
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
predicate.replace(exp.TRUE)
node.where(replace_aliases(node, predicate), copy=False)
def pushdown_dnf(predicates, scope):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
"""
# find all the tables that can be pushed down to
# these are tables that are referenced in all blocks of a DNF
# (a.x AND b.x) OR (a.y AND c.y)
# only table a can be pushed down
pushdown_tables = set()
for a in predicates:
a_tables = set(exp.column_table_names(a))
for b in predicates:
a_tables &= set(exp.column_table_names(b))
pushdown_tables.update(a_tables)
conditions = {}
# for every pushdown table, find all related conditions in all predicates
# combine them with ORs
# (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope)
if table not in nodes:
continue
predicate_condition = None
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
predicate_condition = (
exp.and_(predicate_condition, condition)
if predicate_condition
else condition
)
if predicate_condition:
conditions[table] = (
exp.or_(conditions[table], predicate_condition)
if table in conditions
else predicate_condition
)
for name, node in nodes.items():
if name not in conditions:
continue
predicate = conditions[name]
if isinstance(node, exp.Join):
node.on(predicate, copy=False)
elif isinstance(node, exp.Select):
node.where(replace_aliases(node, predicate), copy=False)
def nodes_for_predicate(predicate, sources):
nodes = {}
tables = exp.column_table_names(predicate)
where_condition = isinstance(
predicate.find_ancestor(exp.Join, exp.Where), exp.Where
)
for table in tables:
node, source = sources.get(table) or (None, None)
# if the predicate is in a where statement we can try to push it down
# we want to find the root join or from statement
if node and where_condition:
node = node.find_ancestor(exp.Join, exp.From)
# a node can reference a CTE, in which case the predicate should be pushed down into that CTE's select
if isinstance(node, exp.From) and not isinstance(source, exp.Table):
node = source.expression
if isinstance(node, exp.Join):
if node.side:
return {}
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
if not node.args.get("group"):
nodes[table] = node
return nodes
def replace_aliases(source, predicate):
aliases = {}
for select in source.selects:
if isinstance(select, exp.Alias):
aliases[select.alias] = select.this
else:
aliases[select.name] = select
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
return aliases[column.name]
return column
return predicate.transform(_replace_alias)
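A hedged sketch (not part of the diff) of pushing a WHERE predicate into a join's ON clause; the comment shows the expected shape, not an output asserted by the source:
from sqlglot import parse_one

sql = "SELECT * FROM x JOIN y ON x.a = y.a WHERE y.b = 1"
print(pushdown_predicates(parse_one(sql)).sql())
# expected, roughly: SELECT * FROM x JOIN y ON x.a = y.a AND y.b = 1 WHERE TRUE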

85
sqlglot/optimizer/pushdown_projections.py Normal file
View file

@ -0,0 +1,85 @@
from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
def pushdown_projections(expression):
"""
Rewrite sqlglot AST to remove unused column projections.
Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Args:
expression (sqlglot.Expression): expression to optimize
Returns:
sqlglot.Expression: optimized expression
"""
# Map of Scope to all columns being selected by outer queries.
referenced_columns = defaultdict(set)
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This ensures that the set of selected
# columns for a particular scope is completely built by the time we get to it.
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
if scope.expression.args.get("distinct"):
# We can't remove columns from SELECT DISTINCT or UNION DISTINCT
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
left, right = scope.union
referenced_columns[left] = parent_selections
referenced_columns[right] = parent_selections
if isinstance(scope.expression, exp.Select):
_remove_unused_selections(scope, parent_selections)
# Group columns by source name
selects = defaultdict(set)
for col in scope.columns:
table_name = col.table
col_name = col.name
selects[table_name].add(col_name)
# Push the selected columns down to the next scope
for name, (_, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = selects.get(name) or set()
referenced_columns[source].update(columns)
return expression
def _remove_unused_selections(scope, parent_selections):
order = scope.expression.args.get("order")
if order:
# Assume columns without a qualified table are references to output columns
order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
else:
order_refs = set()
new_selections = []
for selection in scope.selects:
if (
SELECT_ALL in parent_selections
or selection.alias_or_name in parent_selections
or selection.alias_or_name in order_refs
):
new_selections.append(selection)
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(alias("1", "_"))
scope.expression.set("expressions", new_selections)

422
sqlglot/optimizer/qualify_columns.py Normal file
View file

@ -0,0 +1,422 @@
import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
def qualify_columns(expression, schema):
"""
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Args:
expression (sqlglot.Expression): expression to qualify
schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)
for scope in traverse_scope(expression):
resolver = _Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
_expand_using(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, SKIP_QUALIFY):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
return expression
def _pop_table_column_aliases(derived_tables):
"""
Remove table column aliases.
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
if isinstance(derived_table, SKIP_QUALIFY):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
table_alias.args.pop("columns", None)
def _expand_using(scope, resolver):
joins = list(scope.expression.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to source names
column_tables = {}
for join in joins:
using = join.args.get("using")
if not using:
continue
join_table = join.this.alias_or_name
columns = {}
for k in scope.selected_sources:
if k in ordered:
for column in resolver.get_source_columns(k):
if column not in columns:
columns[column] = k
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
for identifier in using:
identifier = identifier.name
table = columns.get(identifier)
if not table or identifier not in join_columns:
raise OptimizeError(f"Cannot automatically join: {identifier}")
conditions.append(
exp.condition(
exp.EQ(
this=exp.column(identifier, table=table),
expression=exp.column(identifier, table=join_table),
)
)
)
tables = column_tables.setdefault(identifier, [])
if table not in tables:
tables.append(table)
if join_table not in tables:
tables.append(join_table)
join.args.pop("using")
join.set("on", exp.and_(*conditions))
if column_tables:
for column in scope.columns:
if not column.table and column.name in column_tables:
tables = column_tables[column.name]
coalesce = [exp.column(column.name, table=table) for table in tables]
replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
replacement = exp.alias_(replacement, alias=column.name)
scope.replace(column, replacement)
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
if not group:
return
# Replace references to select aliases
def transform(node, *_):
if isinstance(node, exp.Column) and not node.table:
table = resolver.get_table(node.name)
# Source columns get priority over select aliases
if table:
node.set("table", exp.to_identifier(table))
return node
selects = {s.alias_or_name: s for s in scope.selects}
select = selects.get(node.name)
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
select = select.this
return select.copy()
return node
group.transform(transform, copy=False)
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
def _expand_order_by(scope):
order = scope.expression.args.get("order")
if not order:
return
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds)),
):
ordered.set("this", new_expression)
def _expand_positional_references(scope, expressions):
new_nodes = []
for node in expressions:
if node.is_int:
try:
select = scope.selects[int(node.name) - 1]
except IndexError:
raise OptimizeError(f"Unknown output column: {node.name}")
if isinstance(select, exp.Alias):
select = select.this
new_nodes.append(select.copy())
scope.clear_cache()
else:
new_nodes.append(node)
return new_nodes
def _qualify_columns(scope, resolver):
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
column_name = column.name
if (
column_table
and column_table in scope.sources
and column_name not in resolver.get_source_columns(column_table)
):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_unnest:
if column_name not in resolver.all_columns:
raise OptimizeError(f"Unknown column: {column_name}")
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
# column_table can be '' because a BigQuery UNNEST has no table alias
if column_table:
column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
elif isinstance(expression, exp.Column) and isinstance(
expression.this, exp.Star
):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
else:
new_selections.append(expression)
continue
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table)
table_id = id(table)
for name in columns:
if name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(
alias(column, alias_) if alias_ != name else column
)
scope.expression.set("expressions", new_selections)
def _add_except_columns(expression, tables, except_columns):
except_ = expression.args.get("except")
if not except_:
return
columns = {e.name for e in except_}
for table in tables:
except_columns[id(table)] = columns
def _add_replace_columns(expression, tables, replace_columns):
replace = expression.args.get("replace")
if not replace:
return
columns = {e.this.name: e.alias for e in replace}
for table in tables:
replace_columns[id(table)] = columns
def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
alias_.set("this", selection)
selection = alias_
elif not isinstance(selection, exp.Alias):
alias_ = alias(exp.column(""), f"_col_{i}")
alias_.set("this", selection)
selection = alias_
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
new_selections.append(selection)
scope.expression.set("expressions", new_selections)
def _check_unknown_tables(scope):
if (
scope.external_columns
and not scope.is_unnest
and not scope.is_correlated_subquery
):
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
class _Resolver:
"""
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
"""
def __init__(self, scope, schema):
self.scope = scope
self.schema = schema
self._source_columns = None
self._unambiguous_columns = None
self._all_columns = None
def get_table(self, column_name):
"""
Get the table for a column name.
Args:
column_name (str)
Returns:
(str) table name
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = set(
column
for columns in self._get_all_source_columns().values()
for column in columns
)
return self._all_columns
def get_source_columns(self, name):
"""Resolve the source columns for a given source `name`"""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
source = self.scope.sources[name]
# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
try:
return self.schema.column_names(source)
except Exception as e:
raise OptimizeError(str(e)) from e
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {
k: self.get_source_columns(k) for k in self.scope.selected_sources
}
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
"""
Find all the unambiguous columns in sources.
Args:
source_columns (dict): Mapping of names to source columns
Returns:
dict: Mapping of column name to source name
"""
if not source_columns:
return {}
source_columns = list(source_columns.items())
first_table, first_columns = source_columns[0]
unambiguous_columns = {
col: first_table for col in self._find_unique_columns(first_columns)
}
all_columns = set(unambiguous_columns)
for table, columns in source_columns[1:]:
unique = self._find_unique_columns(columns)
ambiguous = set(all_columns).intersection(unique)
all_columns.update(columns)
for column in ambiguous:
unambiguous_columns.pop(column, None)
for column in unique.difference(ambiguous):
unambiguous_columns[column] = table
return unambiguous_columns
@staticmethod
def _find_unique_columns(columns):
"""
Find the unique columns in a list of columns.
Example:
>>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
['a', 'c']
This is necessary because duplicate column names are ambiguous.
"""
counts = {}
for column in columns:
counts[column] = counts.get(column, 0) + 1
return {column for column, count in counts.items() if count == 1}
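An additional sketch (not part of the diff) showing star expansion through the schema; the comment is the expected rendering:
import sqlglot

schema = {"tbl": {"col1": "INT", "col2": "INT"}}
print(qualify_columns(sqlglot.parse_one("SELECT * FROM tbl"), schema).sql())
# expected: SELECT tbl.col1 AS col1, tbl.col2 AS col2 FROM tbl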

54
sqlglot/optimizer/qualify_tables.py Normal file
View file

@ -0,0 +1,54 @@
import itertools
from sqlglot import alias, exp
from sqlglot.optimizer.scope import traverse_scope
def qualify_tables(expression, db=None, catalog=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
Returns:
sqlglot.Expression: qualified expression
"""
sequence = itertools.count()
for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set(
"alias", exp.TableAlias(this=exp.to_identifier(alias_))
)
scope.rename_source(None, alias_)
for source in scope.sources.values():
if isinstance(source, exp.Table):
identifier = isinstance(source.this, exp.Identifier)
if identifier:
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
if not isinstance(source.parent, exp.Alias):
source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
table=True,
)
)
return expression

25
sqlglot/optimizer/quote_identities.py Normal file
View file

@ -0,0 +1,25 @@
from sqlglot import exp
def quote_identities(expression):
"""
Rewrite sqlglot AST to ensure all identifiers are quoted.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
>>> quote_identities(expression).sql()
'SELECT "x"."a" AS "a" FROM "db"."x"'
Args:
expression (sqlglot.Expression): expression to quote
Returns:
sqlglot.Expression: quoted expression
"""
def qualify(node):
if isinstance(node, exp.Identifier):
node.set("quoted", True)
return node
return expression.transform(qualify, copy=False)

129
sqlglot/optimizer/schema.py Normal file
View file

@ -0,0 +1,129 @@
import abc
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
def column_names(self, table):
"""
Get the column names for a table.
Args:
table (sqlglot.expressions.Table): Table expression instance
Returns:
list[str]: list of column names
"""
class MappingSchema(Schema):
"""
Schema based on a nested mapping.
Args:
schema (dict): Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
"""
def __init__(self, schema):
self.schema = schema
depth = _dict_depth(schema)
if not depth: # {}
self.supported_table_args = []
elif depth == 2: # {table: {col: type}}
self.supported_table_args = ("this",)
elif depth == 3: # {db: {table: {col: type}}}
self.supported_table_args = ("db", "this")
elif depth == 4: # {catalog: {db: {table: {col: type}}}}
self.supported_table_args = ("catalog", "db", "this")
else:
raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def column_names(self, table):
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
args = tuple(table.text(p) for p in self.supported_table_args)
for forbidden in self.forbidden_args:
if table.text(forbidden):
raise ValueError(
f"Schema doesn't support {forbidden}. Received: {table.sql()}"
)
return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
def ensure_schema(schema):
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
def fs_get(table):
name = table.this.name.upper()
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:
return next(reader)
raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path):
"""
Get a value for a nested dictionary.
Args:
d (dict): dictionary
*path (tuple[str, str]): tuples of (name, key)
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
"""
for name, key in path:
d = d.get(key)
if d is None:
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}")
return d
def _dict_depth(d):
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1
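Illustrative usage (not part of the diff) of MappingSchema with a db-qualified mapping; find_all is used to pull the Table node out of a parsed query:
from sqlglot import exp, parse_one

schema = MappingSchema({"db": {"x": {"a": "INT", "b": "TEXT"}}})
table = next(iter(parse_one("SELECT * FROM db.x").find_all(exp.Table)))
print(schema.column_names(table))  # expected: ['a', 'b']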

438
sqlglot/optimizer/scope.py Normal file
View file

@ -0,0 +1,438 @@
from copy import copy
from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
class ScopeType(Enum):
ROOT = auto()
SUBQUERY = auto()
DERIVED_TABLE = auto()
CTE = auto()
UNION = auto()
UNNEST = auto()
class Scope:
"""
Selection scope.
Attributes:
expression (exp.Select|exp.Union): Root expression of this scope
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
a Table expression or another Scope instance. For example:
SELECT * FROM x {"x": Table(this="x")}
SELECT * FROM x AS y {"y": Table(this="x")}
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for its alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
This does not include derived tables or CTEs.
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
a tuple of the left and right child scopes.
"""
def __init__(
self,
expression,
sources=None,
outer_column_list=None,
parent=None,
scope_type=ScopeType.ROOT,
):
self.expression = expression
self.sources = sources or {}
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.union = None
self.clear_cache()
def clear_cache(self):
self._collected = False
self._raw_columns = None
self._derived_tables = None
self._tables = None
self._ctes = None
self._subqueries = None
self._selected_sources = None
self._columns = None
self._external_columns = None
def branch(self, expression, scope_type, add_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
sources = copy(self.sources)
if add_sources:
sources.update(add_sources)
return Scope(
expression=expression.unnest(),
sources=sources,
parent=self,
scope_type=scope_type,
**kwargs,
)
def _collect(self):
self._tables = []
self._ctes = []
self._subqueries = []
self._derived_tables = []
self._raw_columns = []
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
prune = False
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
elif isinstance(node, exp.Table):
self._tables.append(node)
elif isinstance(node, (exp.Unnest, exp.Lateral)):
self._derived_tables.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
prune = True
elif isinstance(node, exp.Subquery) and isinstance(
parent, (exp.From, exp.Join)
):
self._derived_tables.append(node)
prune = True
elif isinstance(node, exp.Subqueryable):
self._subqueries.append(node)
prune = True
self._collected = True
def _ensure_collected(self):
if not self._collected:
self._collect()
def replace(self, old, new):
"""
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Args:
old (exp.Expression): old node
new (exp.Expression): new node
"""
old.replace(new)
self.clear_cache()
@property
def tables(self):
"""
List of tables in this scope.
Returns:
list[exp.Table]: tables
"""
self._ensure_collected()
return self._tables
@property
def ctes(self):
"""
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
"""
self._ensure_collected()
return self._ctes
@property
def derived_tables(self):
"""
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
"""
self._ensure_collected()
return self._derived_tables
@property
def subqueries(self):
"""
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
"""
self._ensure_collected()
return self._subqueries
@property
def columns(self):
"""
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any
Columns that reference this scope from correlated subqueries.
"""
if self._columns is None:
self._ensure_collected()
columns = self._raw_columns
external_columns = [
column
for scope in self.subquery_scopes
for column in scope.external_columns
]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
if not (
c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
)
]
return self._columns
@property
def selected_sources(self):
"""
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a
table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
"""
if self._selected_sources is None:
referenced_names = []
for table in self.tables:
referenced_names.append(
(
table.parent.alias
if isinstance(table.parent, exp.Alias)
else table.name,
table,
)
)
for derived_table in self.derived_tables:
referenced_names.append((derived_table.alias, derived_table.unnest()))
result = {}
for name, node in referenced_names:
if name in self.sources:
result[name] = (node, self.sources[name])
self._selected_sources = result
return self._selected_sources
@property
def selects(self):
"""
Select expressions of this scope.
For example, for the following expression:
SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
"""
if isinstance(self.expression, exp.Union):
return []
return self.expression.selects
@property
def external_columns(self):
"""
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns
def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
Args:
source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
"""
return [column for column in self.columns if column.table == source_name]
@property
def is_subquery(self):
"""Determine if this scope is a subquery"""
return self.scope_type == ScopeType.SUBQUERY
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
return self.scope_type == ScopeType.UNNEST
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
return bool(self.is_subquery and self.external_columns)
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""
columns = self.sources.pop(old_name or "", [])
self.sources[new_name] = columns
def traverse_scope(expression):
"""
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than
the expression tree itself. For example, we might care about the source
names within a subquery. Returns a list because a generator could result in
incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Args:
expression (exp.Expression): expression to traverse
Returns:
List[Scope]: scope instances
"""
return list(_traverse_scope(Scope(expression)))
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
pass
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
else:
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
yield scope
def _traverse_select(scope):
yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
yield from _traverse_subqueries(scope)
yield from _traverse_derived_tables(
scope.derived_tables, scope, ScopeType.DERIVED_TABLE
)
_add_table_sources(scope)
def _traverse_union(scope):
yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
# The last scope to be yielded should be the topmost scope
left = None
for left in _traverse_scope(
scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
):
yield left
right = None
for right in _traverse_scope(
scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
):
yield right
scope.union = (left, right)
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
for derived_table in derived_tables:
for child_scope in _traverse_scope(
scope.branch(
derived_table
if isinstance(derived_table, (exp.Unnest, exp.Lateral))
else derived_table.this,
add_sources=sources if scope_type == ScopeType.CTE else None,
outer_column_list=derived_table.alias_column_names,
scope_type=ScopeType.UNNEST
if isinstance(derived_table, exp.Unnest)
else scope_type,
)
):
yield child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed
# (or rather, the latest one wins).
sources[derived_table.alias] = child_scope
scope.sources.update(sources)
def _add_table_sources(scope):
sources = {}
for table in scope.tables:
table_name = table.name
if isinstance(table.parent, exp.Alias):
source_name = table.parent.alias
else:
source_name = table_name
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
elif source_name in scope.sources:
raise OptimizeError(f"Duplicate table name: {source_name}")
else:
sources[source_name] = table
scope.sources.update(sources)
def _traverse_subqueries(scope):
for subquery in scope.subqueries:
top = None
for child_scope in _traverse_scope(
scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
):
yield child_scope
top = child_scope
scope.subquery_scopes.append(top)
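# --- Illustrative sketch (not part of the module) ---------------------------
# Hedged example of how the scope properties above fit together: in a
# correlated EXISTS subquery, the inner scope sees y as a source but not x,
# so x.a lands in external_columns and is_correlated_subquery is True.
# The SQL and outputs below are assumptions from reading the code, not
# captured from a run.
#
#   >>> import sqlglot
#   >>> sql = "SELECT * FROM x AS x WHERE EXISTS(SELECT 1 FROM y AS y WHERE y.a = x.a)"
#   >>> inner = traverse_scope(sqlglot.parse_one(sql))[0]
#   >>> [c.sql() for c in inner.external_columns]
#   ['x.a']
#   >>> inner.is_correlated_subquery
#   True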

View file

@ -0,0 +1,383 @@
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import while_changing
GENERATOR = Generator(normalize=True, identify=True)
def simplify(expression):
"""
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Args:
expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
"""
def _simplify(expression, root=True):
node = expression
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node)
node = remove_compliments(node)
node.parent = expression.parent
node = simplify_literals(node)
node = simplify_parens(node)
if root:
expression.replace(node)
return node
expression = while_changing(expression, _simplify)
remove_where_true(expression)
return expression
def simplify_not(expression):
"""
De Morgan's Law
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if always_true(expression.this):
return FALSE
if expression.this == FALSE:
return TRUE
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
return expression.this.this
return expression
def flatten(expression):
"""
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
"""
if isinstance(expression, exp.Connector):
for node in expression.args.values():
child = node.unnest()
if isinstance(child, expression.__class__):
node.replace(child)
return expression
def simplify_connectors(expression):
if isinstance(expression, exp.Connector):
left = expression.left
right = expression.right
if left == right:
return left
if isinstance(expression, exp.And):
if NULL in (left, right):
return NULL
if FALSE in (left, right):
return FALSE
if always_true(left) and always_true(right):
return TRUE
if always_true(left):
return right
if always_true(right):
return left
elif isinstance(expression, exp.Or):
if always_true(left) or always_true(right):
return TRUE
if left == FALSE and right == FALSE:
return FALSE
if (
(left == NULL and right == NULL)
or (left == NULL and right == FALSE)
or (left == FALSE and right == NULL)
):
return NULL
if left == FALSE:
return right
if right == FALSE:
return left
return expression
def remove_compliments(expression):
"""
Removing complements:
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
compliment = FALSE if isinstance(expression, exp.And) else TRUE
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
return compliment
return expression
def uniq_sort(expression):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {GENERATOR.generate(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
expression = result_func(*deduped.values())
return expression
def absorb_and_eliminate(expression):
"""
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
if isinstance(a, kind):
aa, ab = a.unnest_operands()
# absorb
if is_complement(b, aa):
aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
elif is_complement(b, ab):
ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
a.flatten()
):
a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
ba, bb = rhs
if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
a.replace(aa)
b.replace(aa)
elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
a.replace(ab)
b.replace(ab)
return expression
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
operands = []
queue = deque(expression.flatten(unnest=False))
size = len(queue)
while queue:
a = queue.popleft()
for b in queue:
result = _simplify_binary(expression, a, b)
if result:
queue.remove(b)
queue.append(result)
break
else:
operands.append(a)
if len(operands) < size:
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
if value[0] == "-":
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
return expression
def _simplify_binary(expression, a, b):
if isinstance(expression, exp.Is):
if isinstance(b, exp.Not):
c = b.this
not_ = True
else:
c = b
not_ = False
if c == NULL:
if isinstance(a, exp.Literal):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
elif NULL in (a, b):
return NULL
if isinstance(expression, exp.EQ) and a == b:
return TRUE
if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
b = int(b.name) if b.is_int else Decimal(b.name)
if isinstance(expression, exp.Add):
return exp.Literal.number(a + b)
if isinstance(expression, exp.Sub):
return exp.Literal.number(a - b)
if isinstance(expression, exp.Mul):
return exp.Literal.number(a * b)
if isinstance(expression, exp.Div):
if isinstance(a, int) and isinstance(b, int):
return exp.Literal.number(a // b)
return exp.Literal.number(a / b)
boolean = eval_boolean(expression, a, b)
if boolean:
return boolean
elif a.is_string and b.is_string:
boolean = eval_boolean(expression, a, b)
if boolean:
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and isinstance(expression, exp.Add):
return date_literal(a + b)
return None
def simplify_parens(expression):
if (
isinstance(expression, exp.Paren)
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
or isinstance(expression.this, (exp.Is, exp.Like))
or not isinstance(expression.this, exp.Binary)
)
):
return expression.this
return expression
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
where.parent.set("where", None)
for join in expression.find_all(exp.Join):
if always_true(join.args.get("on")):
join.set("kind", "CROSS")
join.set("on", None)
def always_true(expression):
return expression == TRUE or isinstance(expression, exp.Literal)
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
if isinstance(expression, exp.NEQ):
return boolean_literal(a != b)
if isinstance(expression, exp.GT):
return boolean_literal(a > b)
if isinstance(expression, exp.GTE):
return boolean_literal(a >= b)
if isinstance(expression, exp.LT):
return boolean_literal(a < b)
if isinstance(expression, exp.LTE):
return boolean_literal(a <= b)
return None
def extract_date(cast):
if cast.args["to"].this == exp.DataType.Type.DATE:
return datetime.date.fromisoformat(cast.name)
return None
def extract_interval(interval):
try:
from dateutil.relativedelta import relativedelta
except ModuleNotFoundError:
return None
n = int(interval.name)
unit = interval.text("unit").lower()
if unit == "year":
return relativedelta(years=n)
if unit == "month":
return relativedelta(months=n)
if unit == "week":
return relativedelta(weeks=n)
if unit == "day":
return relativedelta(days=n)
return None
def date_literal(date):
return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
def boolean_literal(condition):
return TRUE if condition else FALSE
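# --- Illustrative sketch (not part of the module) ---------------------------
# Hedged example combining several of the rewrites above: literal folding
# (simplify_literals / _simplify_binary), boolean evaluation (eval_boolean)
# and connector simplification. The expected result is an assumption based on
# reading the code, not captured output.
#
#   >>> import sqlglot
#   >>> simplify(sqlglot.parse_one("(1 + 1) = 2 AND TRUE")).sql()
#   'TRUE'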

View file

@ -0,0 +1,220 @@
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
def unnest_subqueries(expression):
"""
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert the subquery into a group by so it is not a many-to-many left join.
Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
Unnesting non-correlated subqueries only happens on IN statements or = ANY statements.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
Args:
expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
"""
sequence = itertools.count()
for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
else:
unnest(select, parent, sequence)
return expression
def unnest(select, parent_select, sequence):
predicate = select.find_ancestor(exp.In, exp.Any)
if not predicate or parent_select is not predicate.parent_select:
return
if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
return
if isinstance(predicate, exp.Any):
predicate = predicate.find_ancestor(exp.EQ)
if not predicate or parent_select is not predicate.parent_select:
return
column = _other_operand(predicate)
value = select.selects[0]
alias = _alias(sequence)
on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
_replace(predicate, f"NOT {on.right} IS NULL")
parent_select.join(
select.group_by(value.this, copy=False),
on=on,
join_type="LEFT",
join_alias=alias,
copy=False,
)
def decorrelate(select, parent_select, external_columns, sequence):
where = select.args.get("where")
if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
return
table_alias = _alias(sequence)
keys = []
# for all external columns in the where statement,
# split out the relevant data to convert it into a join
for column in external_columns:
if column.find_ancestor(exp.Where) is not where:
return
predicate = column.find_ancestor(exp.Predicate)
if not predicate or predicate.find_ancestor(exp.Where) is not where:
return
if isinstance(predicate, exp.Binary):
key = (
predicate.right
if any(node is column for node, *_ in predicate.left.walk())
else predicate.left
)
else:
return
keys.append((key, column, predicate))
if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
return
value = select.selects[0]
key_aliases = {}
group_by = []
for key, _, predicate in keys:
# if we filter on the value of the subquery, it needs to be unique
if key == value.this:
key_aliases[key] = value.alias
group_by.append(key)
else:
if key not in key_aliases:
key_aliases[key] = _alias(sequence)
# all predicates that are equalities must also be part of the unique key (the group by)
# so that we don't do a many-to-many join
if isinstance(predicate, exp.EQ) and key not in group_by:
group_by.append(key)
parent_predicate = select.find_ancestor(exp.Predicate)
# if the value of the subquery is not an agg or a key, we need to collect it into an array
# so that it can be grouped
if not value.find(exp.AggFunc) and value.this not in group_by:
select.select(
f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
)
# exists queries should not have any selects as they only check whether any rows exist
# all selects will be added by the optimizer and only used for join keys
if isinstance(parent_predicate, exp.Exists):
select.args["expressions"] = []
for key, alias in key_aliases.items():
if key in group_by:
# add all keys to the projections of the subquery
# so that we can use it as a join key
if isinstance(parent_predicate, exp.Exists) or key != value.this:
select.select(f"{key} AS {alias}", copy=False)
else:
select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
)
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
else:
parent_predicate = _replace(
parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
)
elif isinstance(parent_predicate, exp.In):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
else:
parent_predicate = _replace(
parent_predicate,
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
select.parent.replace(alias)
for key, column, predicate in keys:
predicate.replace(exp.TRUE)
nested = exp.column(key_aliases[key], table_alias)
if key in group_by:
key.replace(nested)
parent_predicate = _replace(
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
)
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
)
else:
key.replace(exp.to_identifier("_x"))
parent_predicate = _replace(
parent_predicate,
f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
)
parent_select.join(
select.group_by(*group_by, copy=False),
on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
join_type="LEFT",
join_alias=table_alias,
copy=False,
)
def _alias(sequence):
return f"_u_{next(sequence)}"
def _replace(expression, condition):
return expression.replace(exp.condition(condition))
def _other_operand(expression):
if isinstance(expression, exp.In):
return expression.this
if isinstance(expression, exp.Binary):
return expression.right if expression.arg_key == "this" else expression.left
return None
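# --- Illustrative sketch (not part of the module) ---------------------------
# Hedged example of the non-correlated path handled by unnest(): an IN
# subquery with a single projection and no LIMIT/OFFSET is rewritten into a
# LEFT JOIN against the subquery grouped by its projection, with the IN
# predicate replaced by a NOT ... IS NULL check on the join key. The query
# below is illustrative; the exact generated SQL is not shown because it
# depends on the aliases produced.
#
#   >>> import sqlglot
#   >>> e = sqlglot.parse_one("SELECT * FROM x AS x WHERE x.a IN (SELECT y.a FROM y AS y)")
#   >>> unnest_subqueries(e).sql()  # ... LEFT JOIN (... GROUP BY y.a) AS "_u_0" ON x.a = "_u_0".a ...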

2190
sqlglot/parser.py Normal file

File diff suppressed because it is too large

340
sqlglot/planner.py Normal file
View file

@ -0,0 +1,340 @@
import itertools
import math
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
from sqlglot.optimizer.simplify import simplify
class Plan:
def __init__(self, expression):
self.expression = expression
self.root = Step.from_expression(self.expression)
self._dag = {}
@property
def dag(self):
if not self._dag:
dag = {}
nodes = {self.root}
while nodes:
node = nodes.pop()
dag[node] = set()
for dep in node.dependencies:
dag[node].add(dep)
nodes.add(dep)
self._dag = dag
return self._dag
@property
def leaves(self):
return (node for node, deps in self.dag.items() if not deps)
class Step:
@classmethod
def from_expression(cls, expression, ctes=None):
"""
Build a DAG of Steps from a SQL expression.
Giving an expression like:
SELECT x.a, SUM(x.b)
FROM x
JOIN y
ON x.a = y.a
GROUP BY x.a
Transform it into a DAG of the form:
Aggregate(x.a, SUM(x.b))
Join(y)
Scan(x)
Scan(y)
This can then more easily be executed on by an engine.
"""
ctes = ctes or {}
with_ = expression.args.get("with")
# CTEs fall outside the usual scoping rules and are visible to everything in this context.
if with_:
ctes = ctes.copy()
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
step.name = cte.alias
ctes[step.name] = step
from_ = expression.args.get("from")
if from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
"Multi-from statements are unsupported. Run it through the optimizer"
)
step = Scan.from_expression(from_[0], ctes)
else:
raise UnsupportedError("Static selects are unsupported.")
joins = expression.args.get("joins")
if joins:
join = Join.from_joins(joins, ctes)
join.name = step.name
join.add_dependency(step)
step = join
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs, e.g. x + 1 in SUM(x + 1)
aggregations = []
sequence = itertools.count()
for e in expression.expressions:
aggregation = e.find(exp.AggFunc)
if aggregation:
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
for operand in aggregation.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(
exp.column(operands[operand], step.name, quoted=True)
)
else:
projections.append(e)
where = expression.args.get("where")
if where:
step.condition = where.this
group = expression.args.get("group")
if group:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
aggregate.group = [
exp.column(e.alias_or_name, step.name, quoted=True)
for e in group.expressions
]
aggregate.add_dependency(step)
step = aggregate
having = expression.args.get("having")
if having:
step.condition = having.this
order = expression.args.get("order")
if order:
sort = Sort()
sort.name = step.name
sort.key = order.expressions
sort.add_dependency(step)
step = sort
for k in sort.key + projections:
for column in k.find_all(exp.Column):
column.set("table", exp.to_identifier(step.name, quoted=True))
step.projections = projections
limit = expression.args.get("limit")
if limit:
step.limit = int(limit.text("expression"))
return step
def __init__(self):
self.name = None
self.dependencies = set()
self.dependents = set()
self.projections = []
self.limit = math.inf
self.condition = None
def add_dependency(self, dependency):
self.dependencies.add(dependency)
dependency.dependents.add(self)
def __repr__(self):
return self.to_s()
def to_s(self, level=0):
indent = " " * level
nested = f"{indent} "
context = self._to_s(f"{nested} ")
if context:
context = [f"{nested}Context:"] + context
lines = [
f"{indent}- {self.__class__.__name__}: {self.name}",
*context,
f"{nested}Projections:",
]
for expression in self.projections:
lines.append(f"{nested} - {expression.sql()}")
if self.condition:
lines.append(f"{nested}Condition: {self.condition.sql()}")
if self.dependencies:
lines.append(f"{nested}Dependencies:")
for dependency in self.dependencies:
lines.append(" " + dependency.to_s(level + 1))
return "\n".join(lines)
def _to_s(self, _indent):
return []
class Scan(Step):
@classmethod
def from_expression(cls, expression, ctes=None):
table = expression.this
alias_ = expression.alias
if not alias_:
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
if isinstance(expression, exp.Subquery):
step = Step.from_expression(table, ctes)
step.name = alias_
return step
step = Scan()
step.name = alias_
step.source = expression
if table.name in ctes:
step.add_dependency(ctes[table.name])
return step
def __init__(self):
super().__init__()
self.source = None
def _to_s(self, indent):
return [f"{indent}Source: {self.source.sql()}"]
class Write(Step):
pass
class Join(Step):
@classmethod
def from_joins(cls, joins, ctes=None):
step = Join()
for join in joins:
name = join.this.alias
on = join.args.get("on") or exp.TRUE
source_key = []
join_key = []
# find the join keys
# SELECT
# FROM x
# JOIN y
# ON x.a = y.b AND y.b > 1
#
# should pull y.b as the join key and x.a as the source key
for condition in on.flatten() if isinstance(on, exp.And) else [on]:
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
right_tables = exp.column_table_names(right)
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
condition.replace(exp.TRUE)
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
condition.replace(exp.TRUE)
on = simplify(on)
step.joins[name] = {
"side": join.side,
"join_key": join_key,
"source_key": source_key,
"condition": None if on == exp.TRUE else on,
}
step.add_dependency(Scan.from_expression(join.this, ctes))
return step
def __init__(self):
super().__init__()
self.joins = {}
def _to_s(self, indent):
lines = []
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"):
lines.append(f"{indent}On: {join['condition'].sql()}")
return lines
class Aggregate(Step):
def __init__(self):
super().__init__()
self.aggregations = []
self.operands = []
self.group = []
self.source = None
def _to_s(self, indent):
lines = [f"{indent}Aggregations:"]
for expression in self.aggregations:
lines.append(f"{indent} - {expression.sql()}")
if self.group:
lines.append(f"{indent}Group:")
for expression in self.group:
lines.append(f"{indent} - {expression.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
for expression in self.operands:
lines.append(f"{indent} - {expression.sql()}")
return lines
class Sort(Step):
def __init__(self):
super().__init__()
self.key = None
def _to_s(self, indent):
lines = [f"{indent}Key:"]
for expression in self.key:
lines.append(f"{indent} - {expression.sql()}")
return lines
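# --- Illustrative sketch (not part of the module) ---------------------------
# Rough example of building a plan from the kind of query described in the
# Step.from_expression docstring. The planner expects aliased tables (ideally
# SQL already run through the optimizer), so the query shape and the resulting
# step chain (Aggregate over Join over Scans) are assumptions rather than
# verified output.
#
#   >>> import sqlglot
#   >>> plan = Plan(sqlglot.parse_one(
#   ...     "SELECT x.a, SUM(x.b) AS b FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a"
#   ... ))
#   >>> plan.root          # expected: an Aggregate step depending on a Join of Scan(x) and Scan(y)
#   >>> list(plan.leaves)  # the Scan steps with no dependencies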

45
sqlglot/time.py Normal file
View file

@ -0,0 +1,45 @@
# the generic time format is based on python time.strftime
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import in_trie, new_trie
def format_time(string, mapping, trie=None):
"""
Converts a time string given a mapping.
Examples:
>>> format_time("%Y", {"%Y": "YYYY"})
'YYYY'
Args:
mapping: Dictionary of time format to target time format
trie: Optional trie, can be passed in for performance
"""
start = 0
end = 1
size = len(string)
trie = trie or new_trie(mapping)
current = trie
chunks = []
sym = None
while end <= size:
chars = string[start:end]
result, current = in_trie(current, chars[-1])
if result == 0:
if sym:
end -= 1
chars = sym
sym = None
start += len(chars)
chunks.append(chars)
current = trie
elif result == 2:
sym = chars
end += 1
if result and end > size:
chunks.append(chars)
return "".join(mapping.get(chars, chars) for chars in chunks)

853
sqlglot/tokens.py Normal file
View file

@ -0,0 +1,853 @@
from enum import auto
from sqlglot.helper import AutoName
from sqlglot.trie import in_trie, new_trie
class TokenType(AutoName):
L_PAREN = auto()
R_PAREN = auto()
L_BRACKET = auto()
R_BRACKET = auto()
L_BRACE = auto()
R_BRACE = auto()
COMMA = auto()
DOT = auto()
DASH = auto()
PLUS = auto()
COLON = auto()
DCOLON = auto()
SEMICOLON = auto()
STAR = auto()
SLASH = auto()
LT = auto()
LTE = auto()
GT = auto()
GTE = auto()
NOT = auto()
EQ = auto()
NEQ = auto()
AND = auto()
OR = auto()
AMP = auto()
DPIPE = auto()
PIPE = auto()
CARET = auto()
TILDA = auto()
ARROW = auto()
DARROW = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
SPACE = auto()
BREAK = auto()
STRING = auto()
NUMBER = auto()
IDENTIFIER = auto()
COLUMN = auto()
COLUMN_DEF = auto()
SCHEMA = auto()
TABLE = auto()
VAR = auto()
BIT_STRING = auto()
# types
BOOLEAN = auto()
TINYINT = auto()
SMALLINT = auto()
INT = auto()
BIGINT = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
CHAR = auto()
NCHAR = auto()
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
BYTEA = auto()
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
DATETIME = auto()
DATE = auto()
UUID = auto()
GEOGRAPHY = auto()
NULLABLE = auto()
# keywords
ADD_FILE = auto()
ALIAS = auto()
ALL = auto()
ALTER = auto()
ANALYZE = auto()
ANY = auto()
ARRAY = auto()
ASC = auto()
AT_TIME_ZONE = auto()
AUTO_INCREMENT = auto()
BEGIN = auto()
BETWEEN = auto()
BUCKET = auto()
CACHE = auto()
CALL = auto()
CASE = auto()
CAST = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMENT = auto()
COMMIT = auto()
CONSTRAINT = auto()
CONVERT = auto()
CREATE = auto()
CROSS = auto()
CUBE = auto()
CURRENT_DATE = auto()
CURRENT_DATETIME = auto()
CURRENT_ROW = auto()
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
DIV = auto()
DEFAULT = auto()
DELETE = auto()
DESC = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DROP = auto()
ELSE = auto()
END = auto()
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
EXISTS = auto()
EXPLAIN = auto()
EXTRACT = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
FINAL = auto()
FIRST = auto()
FOLLOWING = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
FULL = auto()
FUNCTION = auto()
FROM = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
HINT = auto()
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
INSERT = auto()
INTERSECT = auto()
INTERVAL = auto()
INTO = auto()
INTRODUCER = auto()
IS = auto()
ISNULL = auto()
JOIN = auto()
LATERAL = auto()
LAZY = auto()
LEFT = auto()
LIKE = auto()
LIMIT = auto()
LOCATION = auto()
MAP = auto()
MOD = auto()
NEXT = auto()
NO_ACTION = auto()
NULL = auto()
NULLS_FIRST = auto()
NULLS_LAST = auto()
OFFSET = auto()
ON = auto()
ONLY = auto()
OPTIMIZE = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()
OUT_OF = auto()
OVER = auto()
OVERWRITE = auto()
PARTITION = auto()
PARTITION_BY = auto()
PARTITIONED_BY = auto()
PERCENT = auto()
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROPERTIES = auto()
QUALIFY = auto()
QUOTE = auto()
RANGE = auto()
RECURSIVE = auto()
REPLACE = auto()
RESPECT_NULLS = auto()
REFERENCES = auto()
RIGHT = auto()
RLIKE = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
SCHEMA_COMMENT = auto()
SELECT = auto()
SET = auto()
SHOW = auto()
SOME = auto()
SORT_BY = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TIME = auto()
TOP = auto()
THEN = auto()
TRUE = auto()
TRUNCATE = auto()
TRY_CAST = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
UNNEST = auto()
UPDATE = auto()
USE = auto()
USING = auto()
VALUES = auto()
VIEW = auto()
WHEN = auto()
WHERE = auto()
WINDOW = auto()
WITH = auto()
WITH_TIME_ZONE = auto()
WITHIN_GROUP = auto()
WITHOUT_TIME_ZONE = auto()
UNIQUE = auto()
class Token:
__slots__ = ("token_type", "text", "line", "col")
@classmethod
def number(cls, number):
return cls(TokenType.NUMBER, str(number))
@classmethod
def string(cls, string):
return cls(TokenType.STRING, string)
@classmethod
def identifier(cls, identifier):
return cls(TokenType.IDENTIFIER, identifier)
@classmethod
def var(cls, var):
return cls(TokenType.VAR, var)
def __init__(self, token_type, text, line=1, col=1):
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
def __repr__(self):
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass.QUOTES = dict(
(quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
for quote in klass.QUOTES
)
klass.IDENTIFIERS = dict(
(identifier, identifier)
if isinstance(identifier, str)
else (identifier[0], identifier[1])
for identifier in klass.IDENTIFIERS
)
klass.COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
for key, value in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass.COMMENTS},
**{quote: TokenType.QUOTE for quote in klass.QUOTES},
}.items()
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
class Tokenizer(metaclass=_Tokenizer):
SINGLE_TOKENS = {
"(": TokenType.L_PAREN,
")": TokenType.R_PAREN,
"[": TokenType.L_BRACKET,
"]": TokenType.R_BRACKET,
"{": TokenType.L_BRACE,
"}": TokenType.R_BRACE,
"&": TokenType.AMP,
"^": TokenType.CARET,
":": TokenType.COLON,
",": TokenType.COMMA,
".": TokenType.DOT,
"-": TokenType.DASH,
"=": TokenType.EQ,
">": TokenType.GT,
"<": TokenType.LT,
"%": TokenType.MOD,
"!": TokenType.NOT,
"|": TokenType.PIPE,
"+": TokenType.PLUS,
";": TokenType.SEMICOLON,
"/": TokenType.SLASH,
"*": TokenType.STAR,
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"#": TokenType.ANNOTATION,
"$": TokenType.DOLLAR,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
}
QUOTES = ["'"]
IDENTIFIERS = ['"']
ESCAPE = "'"
KEYWORDS = {
"/*+": TokenType.HINT,
"*/": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
"||": TokenType.DPIPE,
">=": TokenType.GTE,
"<=": TokenType.LTE,
"<>": TokenType.NEQ,
"!=": TokenType.NEQ,
"->": TokenType.ARROW,
"->>": TokenType.DARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"ADD ARCHIVE": TokenType.ADD_FILE,
"ADD ARCHIVES": TokenType.ADD_FILE,
"ADD FILE": TokenType.ADD_FILE,
"ADD FILES": TokenType.ADD_FILE,
"ADD JAR": TokenType.ADD_FILE,
"ADD JARS": TokenType.ADD_FILE,
"ALL": TokenType.ALL,
"ALTER": TokenType.ALTER,
"ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND,
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"AT TIME ZONE": TokenType.AT_TIME_ZONE,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
"BUCKET": TokenType.BUCKET,
"CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CAST": TokenType.CAST,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COMMENT": TokenType.SCHEMA_COMMENT,
"COMMIT": TokenType.COMMIT,
"CONSTRAINT": TokenType.CONSTRAINT,
"CONVERT": TokenType.CONVERT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
"CUBE": TokenType.CUBE,
"CURRENT_DATE": TokenType.CURRENT_DATE,
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
"DIV": TokenType.DIV,
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
"END": TokenType.END,
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
"EXISTS": TokenType.EXISTS,
"EXPLAIN": TokenType.EXPLAIN,
"EXTRACT": TokenType.EXTRACT,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
"FIRST": TokenType.FIRST,
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
"INNER": TokenType.INNER,
"INSERT": TokenType.INSERT,
"INTERVAL": TokenType.INTERVAL,
"INTERSECT": TokenType.INTERSECT,
"INTO": TokenType.INTO,
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEFT": TokenType.LEFT,
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
"NOT": TokenType.NOT,
"NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
"OPTIMIZE": TokenType.OPTIMIZE,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
"ORDINALITY": TokenType.ORDINALITY,
"OUTER": TokenType.OUTER,
"OUT OF": TokenType.OUT_OF,
"OVER": TokenType.OVER,
"OVERWRITE": TokenType.OVERWRITE,
"PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY,
"PARTITIONED BY": TokenType.PARTITIONED_BY,
"PERCENT": TokenType.PERCENT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"RANGE": TokenType.RANGE,
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SELECT": TokenType.SELECT,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
"TBLPROPERTIES": TokenType.PROPERTIES,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRUNCATE": TokenType.TRUNCATE,
"TRY_CAST": TokenType.TRY_CAST,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNNEST": TokenType.UNNEST,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"WHEN": TokenType.WHEN,
"WHERE": TokenType.WHERE,
"WITH": TokenType.WITH,
"WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"ARRAY": TokenType.ARRAY,
"BOOL": TokenType.BOOLEAN,
"BOOLEAN": TokenType.BOOLEAN,
"BYTE": TokenType.TINYINT,
"TINYINT": TokenType.TINYINT,
"SHORT": TokenType.SMALLINT,
"SMALLINT": TokenType.SMALLINT,
"INT2": TokenType.SMALLINT,
"INTEGER": TokenType.INT,
"INT": TokenType.INT,
"INT4": TokenType.INT,
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.BIGINT,
"DECIMAL": TokenType.DECIMAL,
"MAP": TokenType.MAP,
"NUMBER": TokenType.DECIMAL,
"NUMERIC": TokenType.DECIMAL,
"FIXED": TokenType.DECIMAL,
"REAL": TokenType.FLOAT,
"FLOAT": TokenType.FLOAT,
"FLOAT4": TokenType.FLOAT,
"FLOAT8": TokenType.DOUBLE,
"DOUBLE": TokenType.DOUBLE,
"JSON": TokenType.JSON,
"CHAR": TokenType.CHAR,
"NCHAR": TokenType.NCHAR,
"VARCHAR": TokenType.VARCHAR,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
"BINARY": TokenType.BINARY,
"BLOB": TokenType.BINARY,
"BYTEA": TokenType.BINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"DATE": TokenType.DATE,
"DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
}
WHITE_SPACE = {
" ": TokenType.SPACE,
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
"\r": TokenType.BREAK,
"\r\n": TokenType.BREAK,
}
COMMANDS = {
TokenType.ALTER,
TokenType.ADD_FILE,
TokenType.ANALYZE,
TokenType.BEGIN,
TokenType.CALL,
TokenType.COMMIT,
TokenType.EXPLAIN,
TokenType.OPTIMIZE,
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.USE,
}
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS = {}
ENCODE = None
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
__slots__ = (
"sql",
"size",
"tokens",
"_start",
"_current",
"_line",
"_col",
"_char",
"_end",
"_peek",
)
def __init__(self):
"""
Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
"""
self.reset()
def reset(self):
self.sql = ""
self.size = 0
self.tokens = []
self._start = 0
self._current = 0
self._line = 1
self._col = 1
self._char = None
self._end = None
self._peek = None
def tokenize(self, sql):
self.reset()
self.sql = sql
self.size = len(sql)
while self.size and not self._end:
self._start = self._current
self._advance()
if not self._char:
break
white_space = self.WHITE_SPACE.get(self._char)
identifier_end = self.IDENTIFIERS.get(self._char)
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
elif self._char == "0" and self._peek == "x":
self._scan_hex()
elif self._char.isdigit():
self._scan_number()
elif identifier_end:
self._scan_identifier(identifier_end)
else:
self._scan_keywords()
return self.tokens
def _chars(self, size):
if size == 1:
return self._char
start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""
def _advance(self, i=1):
self._col += i
self._current += i
self._end = self._current >= self.size
self._char = self.sql[self._current - 1]
self._peek = self.sql[self._current] if self._current < self.size else ""
@property
def _text(self):
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
text = self._text if text is None else text
self.tokens.append(Token(token_type, text, self._line, self._col))
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
def _scan_keywords(self):
size = 0
word = None
chars = self._text
char = chars
prev_space = False
skip = False
trie = self.KEYWORD_TRIE
while chars:
if skip:
result = 1
else:
result, trie = in_trie(trie, char.upper())
if result == 0:
break
if result == 2:
word = chars
size += 1
end = self._current - 1 + size
if end < self.size:
char = self.sql[end]
is_space = char in self.WHITE_SPACE
if not is_space or not prev_space:
if is_space:
char = " "
chars += char
prev_space = is_space
skip = False
else:
skip = True
else:
chars = None
if not word:
if self._char in self.SINGLE_TOKENS:
token = self.SINGLE_TOKENS[self._char]
if token == TokenType.ANNOTATION:
self._scan_annotation()
return
self._add(token)
return
self._scan_var()
return
if self._scan_string(word):
return
if self._scan_comment(word):
return
self._advance(size - 1)
self._add(self.KEYWORDS[word.upper()])
def _scan_comment(self, comment_start):
if comment_start not in self.COMMENTS:
return False
comment_end = self.COMMENTS[comment_start]
if comment_end:
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
self._advance()
return True
def _scan_annotation(self):
while (
not self._end
and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
and self._peek != ","
):
self._advance()
self._add(TokenType.ANNOTATION, self._text[1:])
def _scan_number(self):
decimal = False
scientific = 0
while True:
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
decimal = True
self._advance()
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
elif self._peek.upper() == "E" and not scientific:
scientific += 1
self._advance()
elif self._peek.isalpha():
self._add(TokenType.NUMBER)
literal = []
while self._peek.isalpha():
literal.append(self._peek.upper())
self._advance()
literal = "".join(literal)
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
if token_type:
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
def _scan_hex(self):
self._advance()
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
break
try:
self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
except ValueError:
self._add(TokenType.IDENTIFIER)
def _scan_string(self, quote):
quote_end = self.QUOTES.get(quote)
if quote_end is None:
return False
text = ""
self._advance(len(quote))
quote_end_size = len(quote_end)
while True:
if self._char == self.ESCAPE and self._peek == quote_end:
text += quote
self._advance(2)
else:
if self._chars(quote_end_size) == quote_end:
if quote_end_size > 1:
self._advance(quote_end_size - 1)
break
if self._end:
raise RuntimeError(
f"Missing {quote} from {self._line}:{self._start}"
)
text += self._char
self._advance()
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
self._add(TokenType.STRING, text)
return True
def _scan_identifier(self, identifier_end):
while self._peek != identifier_end:
if self._end:
raise RuntimeError(
f"Missing {identifier_end} from {self._line}:{self._start}"
)
self._advance()
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
def _scan_var(self):
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
break
self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
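# --- Illustrative sketch (not part of the module) ---------------------------
# Hedged example of the tokenizer: multi-word keywords such as "GROUP BY" come
# back as a single token thanks to the keyword trie built in _Tokenizer.__new__,
# while bare identifiers fall back to VAR. The token types listed are
# expectations from reading the code, not captured output.
#
#   >>> tokens = Tokenizer().tokenize("SELECT a FROM x GROUP BY a")
#   >>> [t.token_type for t in tokens]
#   [TokenType.SELECT, TokenType.VAR, TokenType.FROM, TokenType.VAR, TokenType.GROUP_BY, TokenType.VAR]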

68
sqlglot/transforms.py Normal file
View file

@ -0,0 +1,68 @@
from sqlglot import expressions as exp
def unalias_group(expression):
"""
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
e.alias: i
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
}
expression = expression.copy()
for col in expression.find_all(exp.Column):
alias_index = aliased_selects.get(col.name)
if not col.table and alias_index:
col.replace(exp.Literal.number(alias_index))
return expression
def preprocess(transforms, to_sql):
"""
Create a new transform function that can be used a value in `Generator.TRANSFORMS`
to convert expressions to SQL.
Args:
transforms (list[(exp.Expression) -> exp.Expression]):
Sequence of transform functions. These will be called in order.
to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
Final transform that converts the resulting expression to a SQL string.
Returns:
(sqlglot.generator.Generator, exp.Expression) -> str:
Function that can be used as a generator transform.
"""
def _to_sql(self, expression):
expression = transforms[0](expression)
for t in transforms[1:]:
expression = t(expression)
return to_sql(self, expression)
return _to_sql
def delegate(attr):
"""
Create a new method that delegates to `attr`.
This is useful for creating `Generator.TRANSFORMS` functions that delegate
to existing generator methods.
"""
def _transform(self, *args, **kwargs):
return getattr(self, attr)(*args, **kwargs)
return _transform
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}

27
sqlglot/trie.py Normal file
View file

@ -0,0 +1,27 @@
def new_trie(keywords):
trie = {}
for key in keywords:
current = trie
for char in key:
current = current.setdefault(char, {})
current[0] = True
return trie
def in_trie(trie, key):
if not key:
return (0, trie)
current = trie
for char in key:
if char not in current:
return (0, current)
current = current[char]
if 0 in current:
return (2, current)
return (1, current)
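# --- Illustrative sketch (not part of the module) ---------------------------
# The return code from in_trie distinguishes "no match" (0), "prefix of some
# keyword" (1) and "complete keyword" (2); the tokenizer uses this to scan
# multi-word keywords one character at a time. Values below are inferred from
# the code.
#
#   >>> trie = new_trie(["CAT", "CATS"])
#   >>> in_trie(trie, "DOG")[0]
#   0
#   >>> in_trie(trie, "CA")[0]
#   1
#   >>> in_trie(trie, "CAT")[0]
#   2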

0
tests/__init__.py Normal file
View file

@ -0,0 +1,238 @@
from sqlglot import ErrorLevel, ParseError, UnsupportedError, transpile
from tests.dialects.test_dialect import Validator
class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
self.validate_all(
'"""x"""',
write={
"bigquery": "'x'",
"duckdb": "'x'",
"presto": "'x'",
"hive": "'x'",
"spark": "'x'",
},
)
self.validate_all(
'"""x\'"""',
write={
"bigquery": "'x\\''",
"duckdb": "'x'''",
"presto": "'x'''",
"hive": "'x\\''",
"spark": "'x\\''",
},
)
self.validate_all(
r'r"""/\*.*\*/"""',
write={
"bigquery": r"'/\\*.*\\*/'",
"duckdb": r"'/\*.*\*/'",
"presto": r"'/\*.*\*/'",
"hive": r"'/\\*.*\\*/'",
"spark": r"'/\\*.*\\*/'",
},
)
self.validate_all(
R'R"""/\*.*\*/"""',
write={
"bigquery": R"'/\\*.*\\*/'",
"duckdb": R"'/\*.*\*/'",
"presto": R"'/\*.*\*/'",
"hive": R"'/\\*.*\\*/'",
"spark": R"'/\\*.*\\*/'",
},
)
self.validate_all(
"CAST(a AS INT64)",
write={
"bigquery": "CAST(a AS INT64)",
"duckdb": "CAST(a AS BIGINT)",
"presto": "CAST(a AS BIGINT)",
"hive": "CAST(a AS BIGINT)",
"spark": "CAST(a AS LONG)",
},
)
self.validate_all(
"CAST(a AS NUMERIC)",
write={
"bigquery": "CAST(a AS NUMERIC)",
"duckdb": "CAST(a AS DECIMAL)",
"presto": "CAST(a AS DECIMAL)",
"hive": "CAST(a AS DECIMAL)",
"spark": "CAST(a AS DECIMAL)",
},
)
self.validate_all(
"[1, 2, 3]",
read={
"duckdb": "LIST_VALUE(1, 2, 3)",
"presto": "ARRAY[1, 2, 3]",
"hive": "ARRAY(1, 2, 3)",
"spark": "ARRAY(1, 2, 3)",
},
write={
"bigquery": "[1, 2, 3]",
"duckdb": "LIST_VALUE(1, 2, 3)",
"presto": "ARRAY[1, 2, 3]",
"hive": "ARRAY(1, 2, 3)",
"spark": "ARRAY(1, 2, 3)",
},
)
self.validate_all(
"SELECT * FROM UNNEST(['7', '14']) AS x",
read={
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
},
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x",
"presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS (x)",
"hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
},
)
self.validate_all(
"x IS unknown",
write={
"bigquery": "x IS NULL",
"duckdb": "x IS NULL",
"presto": "x IS NULL",
"hive": "x IS NULL",
"spark": "x IS NULL",
},
)
self.validate_all(
"current_datetime",
write={
"bigquery": "CURRENT_DATETIME()",
"duckdb": "CURRENT_DATETIME()",
"presto": "CURRENT_DATETIME()",
"hive": "CURRENT_DATETIME()",
"spark": "CURRENT_DATETIME()",
},
)
self.validate_all(
"current_time",
write={
"bigquery": "CURRENT_TIME()",
"duckdb": "CURRENT_TIME()",
"presto": "CURRENT_TIME()",
"hive": "CURRENT_TIME()",
"spark": "CURRENT_TIME()",
},
)
self.validate_all(
"current_timestamp",
write={
"bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP()",
"hive": "CURRENT_TIMESTAMP()",
"spark": "CURRENT_TIMESTAMP()",
},
)
self.validate_all(
"current_timestamp()",
write={
"bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()",
"postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP()",
"hive": "CURRENT_TIMESTAMP()",
"spark": "CURRENT_TIMESTAMP()",
},
)
self.validate_identity(
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
)
self.validate_identity(
"SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
)
self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)",
write={
"bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRING>)",
"duckdb": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b TEXT>)",
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)",
},
)
self.validate_all(
"CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
write={
"bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a BIGINT, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a BIGINT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: LONG, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)",
},
)
self.validate_all(
"SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
write={
"bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
"mysql": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
"presto": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY[1, 2, 3]))",
"hive": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
"spark": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))",
},
)
# Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(UnsupportedError):
transpile(
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)
with self.assertRaises(ParseError):
transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery")
self.validate_all(
"DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)",
write={
"postgres": "CURRENT_DATE - INTERVAL '1' DAY",
},
)
self.validate_all(
"DATE_ADD(CURRENT_DATE(), INTERVAL 1 DAY)",
write={
"bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
"duckdb": "CURRENT_DATE + INTERVAL 1 DAY",
"mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
"postgres": "CURRENT_DATE + INTERVAL '1' DAY",
"presto": "DATE_ADD(DAY, 1, CURRENT_DATE)",
"hive": "DATE_ADD(CURRENT_DATE, 1)",
"spark": "DATE_ADD(CURRENT_DATE, 1)",
},
)
self.validate_all(
"CURRENT_DATE('UTC')",
write={
"mysql": "CURRENT_DATE AT TIME ZONE 'UTC'",
"postgres": "CURRENT_DATE AT TIME ZONE 'UTC'",
},
)
self.validate_all(
"SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
write={
"bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
"snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10",
},
)

View file

@ -0,0 +1,25 @@
from tests.dialects.test_dialect import Validator
class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
self.validate_identity("dictGet(x, 'y')")
self.validate_identity("SELECT * FROM x FINAL")
self.validate_identity("SELECT * FROM x AS y FINAL")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
"clickhouse": "CAST(1 AS Nullable(BIGINT))",
},
)

View file

@ -0,0 +1,981 @@
import unittest
from sqlglot import (
Dialect,
Dialects,
ErrorLevel,
UnsupportedError,
parse_one,
transpile,
)
class Validator(unittest.TestCase):
dialect = None
def validate(self, sql, target, **kwargs):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
Validate that:
1. Everything in `read` transpiles to `sql`
2. `sql` transpiles to everything in `write`
Args:
sql (str): Main SQL expression
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
expression = parse_one(sql, read=self.dialect)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(
self.dialect, unsupported_level=ErrorLevel.IGNORE
),
sql,
)
for write_dialect, write_sql in (write or {}).items():
with self.subTest(f"{sql} -> {write_dialect}"):
if write_sql is UnsupportedError:
with self.assertRaises(UnsupportedError):
expression.sql(
write_dialect, unsupported_level=ErrorLevel.RAISE
)
else:
self.assertEqual(
expression.sql(
write_dialect,
unsupported_level=ErrorLevel.IGNORE,
pretty=pretty,
),
write_sql,
)
class TestDialect(Validator):
maxDiff = None
def test_enum(self):
for dialect in Dialects:
self.assertIsNotNone(Dialect[dialect])
self.assertIsNotNone(Dialect.get(dialect))
self.assertIsNotNone(Dialect.get_or_raise(dialect))
self.assertIsNotNone(Dialect[dialect.value])
def test_cast(self):
self.validate_all(
"CAST(a AS TEXT)",
write={
"bigquery": "CAST(a AS STRING)",
"clickhouse": "CAST(a AS TEXT)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
},
)
self.validate_all(
"CAST(a AS STRING)",
write={
"bigquery": "CAST(a AS STRING)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
},
)
self.validate_all(
"CAST(a AS VARCHAR)",
write={
"bigquery": "CAST(a AS STRING)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS VARCHAR)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS VARCHAR2)",
"postgres": "CAST(a AS VARCHAR)",
"presto": "CAST(a AS VARCHAR)",
"snowflake": "CAST(a AS VARCHAR)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS VARCHAR)",
},
)
self.validate_all(
"CAST(a AS VARCHAR(3))",
write={
"bigquery": "CAST(a AS STRING(3))",
"duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))",
"hive": "CAST(a AS VARCHAR(3))",
"oracle": "CAST(a AS VARCHAR2(3))",
"postgres": "CAST(a AS VARCHAR(3))",
"presto": "CAST(a AS VARCHAR(3))",
"snowflake": "CAST(a AS VARCHAR(3))",
"spark": "CAST(a AS VARCHAR(3))",
"starrocks": "CAST(a AS VARCHAR(3))",
},
)
self.validate_all(
"CAST(a AS SMALLINT)",
write={
"bigquery": "CAST(a AS INT64)",
"duckdb": "CAST(a AS SMALLINT)",
"mysql": "CAST(a AS SMALLINT)",
"hive": "CAST(a AS SMALLINT)",
"oracle": "CAST(a AS NUMBER)",
"postgres": "CAST(a AS SMALLINT)",
"presto": "CAST(a AS SMALLINT)",
"snowflake": "CAST(a AS SMALLINT)",
"spark": "CAST(a AS SHORT)",
"sqlite": "CAST(a AS INTEGER)",
"starrocks": "CAST(a AS SMALLINT)",
},
)
self.validate_all(
"CAST(a AS DOUBLE)",
write={
"bigquery": "CAST(a AS FLOAT64)",
"clickhouse": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
"oracle": "CAST(a AS DOUBLE PRECISION)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"presto": "CAST(a AS DOUBLE)",
"snowflake": "CAST(a AS DOUBLE)",
"spark": "CAST(a AS DOUBLE)",
"starrocks": "CAST(a AS DOUBLE)",
},
)
self.validate_all(
"CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"}
)
self.validate_all(
"CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"}
)
self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"})
self.validate_all("CAST(a AS BIGINT)", write={"oracle": "CAST(a AS NUMBER)"})
self.validate_all("CAST(a AS INT)", write={"oracle": "CAST(a AS NUMBER)"})
self.validate_all(
"CAST(a AS DECIMAL)",
read={"oracle": "CAST(a AS NUMBER)"},
write={"oracle": "CAST(a AS NUMBER)"},
)
def test_time(self):
self.validate_all(
"STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
read={
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
},
write={
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
},
)
self.validate_all(
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
write={
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
},
)
self.validate_all(
"STR_TO_TIME(x, '%y')",
write={
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
"spark": "TO_TIMESTAMP(x, 'yy')",
},
)
self.validate_all(
"STR_TO_UNIX('2020-01-01', '%Y-%M-%d')",
write={
"duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))",
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))",
},
)
self.validate_all(
"TIME_STR_TO_DATE('2020-01-01')",
write={
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
},
)
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01')",
write={
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
},
)
self.validate_all(
"TIME_STR_TO_UNIX('2020-01-01')",
write={
"duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))",
"hive": "UNIX_TIMESTAMP('2020-01-01')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))",
},
)
self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')",
write={
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
},
)
self.validate_all(
"TIME_TO_TIME_STR(x)",
write={
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
},
)
self.validate_all(
"TIME_TO_UNIX(x)",
write={
"duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)",
},
)
self.validate_all(
"TS_OR_DS_TO_DATE_STR(x)",
write={
"duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)",
"hive": "SUBSTRING(CAST(x AS STRING), 1, 10)",
"presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)",
},
)
self.validate_all(
"TS_OR_DS_TO_DATE(x)",
write={
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)",
},
)
self.validate_all(
"TS_OR_DS_TO_DATE(x, '%-d')",
write={
"duckdb": "CAST(STRPTIME(x, '%-d') AS DATE)",
"hive": "TO_DATE(x, 'd')",
"presto": "CAST(DATE_PARSE(x, '%e') AS DATE)",
"spark": "TO_DATE(x, 'd')",
},
)
self.validate_all(
"UNIX_TO_STR(x, y)",
write={
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)",
"hive": "FROM_UNIXTIME(x, y)",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
},
)
self.validate_all(
"UNIX_TO_TIME(x)",
write={
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
"hive": "FROM_UNIXTIME(x)",
"presto": "FROM_UNIXTIME(x)",
},
)
self.validate_all(
"UNIX_TO_TIME_STR(x)",
write={
"duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)",
"hive": "FROM_UNIXTIME(x)",
"presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)",
},
)
self.validate_all(
"DATE_TO_DATE_STR(x)",
write={
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
},
)
self.validate_all(
"DATE_TO_DI(x)",
write={
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
},
)
self.validate_all(
"DI_TO_DATE(x)",
write={
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
},
)
self.validate_all(
"TS_OR_DI_TO_DI(x)",
write={
"duckdb": "CAST(SUBSTR(REPLACE(CAST(x AS TEXT), '-', ''), 1, 8) AS INT)",
"hive": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)",
"presto": "CAST(SUBSTR(REPLACE(CAST(x AS VARCHAR), '-', ''), 1, 8) AS INT)",
"spark": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)",
},
)
self.validate_all(
"DATE_ADD(x, 1, 'day')",
read={
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"postgres": "x + INTERVAL '1' 'day'",
"presto": "DATE_ADD('day', 1, x)",
"spark": "DATE_ADD(x, 1)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
)
self.validate_all(
"DATE_ADD(x, y, 'day')",
write={
"postgres": UnsupportedError,
},
)
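# Editorial sketch, not part of the upstream suite: an UnsupportedError value in `write`
# (as in the postgres entry above) means generation fails once unsupported errors are
# raised rather than ignored, roughly:
#
#     from sqlglot import transpile
#     from sqlglot.errors import ErrorLevel, UnsupportedError
#     try:
#         transpile("DATE_ADD(x, y, 'day')", write="postgres", unsupported_level=ErrorLevel.RAISE)
#     except UnsupportedError:
#         pass  # per the test above, postgres cannot express a non-literal interval amount here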
self.validate_all(
"DATE_ADD(x, 1)",
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"presto": "DATE_ADD('day', 1, x)",
"spark": "DATE_ADD(x, 1)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
)
self.validate_all(
"DATE_TRUNC(x, 'day')",
write={
"mysql": "DATE(x)",
"starrocks": "DATE(x)",
},
)
self.validate_all(
"DATE_TRUNC(x, 'week')",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'month')",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'quarter')",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'year')",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'millenium')",
write={
"mysql": UnsupportedError,
"starrocks": UnsupportedError,
},
)
self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%dT%H:%M:%S')",
read={
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
},
write={
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)",
"spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
},
)
self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%d')",
write={
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
},
)
self.validate_all(
"DATE_STR_TO_DATE(x)",
write={
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
},
)
self.validate_all(
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
write={
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
"spark": "DATE_ADD('2021-02-01', 1)",
},
)
self.validate_all(
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
write={
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
f"{unit}(x)",
read={
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
"duckdb",
"mysql",
"presto",
"starrocks",
)
},
write={
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
"duckdb",
"mysql",
"presto",
"hive",
"spark",
"starrocks",
)
},
)
def test_array(self):
self.validate_all(
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"ARRAY_SIZE(x)",
write={
"bigquery": "ARRAY_LENGTH(x)",
"duckdb": "ARRAY_LENGTH(x)",
"presto": "CARDINALITY(x)",
"spark": "SIZE(x)",
},
)
self.validate_all(
"ARRAY_SUM(ARRAY(1, 2))",
write={
"trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)",
"duckdb": "LIST_SUM(LIST_VALUE(1, 2))",
"hive": "ARRAY_SUM(ARRAY(1, 2))",
"presto": "ARRAY_SUM(ARRAY[1, 2])",
"spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)",
},
)
self.validate_all(
"REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
write={
"trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)",
},
)
def test_order_by(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
def test_json(self):
self.validate_all(
"JSON_EXTRACT(x, 'y')",
read={
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')",
},
write={
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')",
},
)
self.validate_all(
"JSON_EXTRACT_SCALAR(x, 'y')",
read={
"postgres": "x->>'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')",
},
write={
"postgres": "x->>'y'",
"presto": "JSON_EXTRACT_SCALAR(x, 'y')",
},
)
self.validate_all(
"JSONB_EXTRACT(x, 'y')",
read={
"postgres": "x#>'y'",
},
write={
"postgres": "x#>'y'",
},
)
self.validate_all(
"JSONB_EXTRACT_SCALAR(x, 'y')",
read={
"postgres": "x#>>'y'",
},
write={
"postgres": "x#>>'y'",
},
)
def test_cross_join(self):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
)
self.validate_all(
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
write={
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
},
)
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t (a)",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
},
)
def test_set_operators(self):
self.validate_all(
"SELECT * FROM a UNION SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a UNION ALL SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b",
"presto": "SELECT * FROM a UNION ALL SELECT * FROM b",
"spark": "SELECT * FROM a UNION ALL SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b",
"presto": "SELECT * FROM a UNION ALL SELECT * FROM b",
"spark": "SELECT * FROM a UNION ALL SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a INTERSECT SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a EXCEPT SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a UNION DISTINCT SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
},
)
self.validate_all(
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
},
)
def test_operators(self):
self.validate_all(
"x ILIKE '%y'",
read={
"clickhouse": "x ILIKE '%y'",
"duckdb": "x ILIKE '%y'",
"postgres": "x ILIKE '%y'",
"snowflake": "x ILIKE '%y'",
},
write={
"bigquery": "LOWER(x) LIKE '%y'",
"clickhouse": "x ILIKE '%y'",
"duckdb": "x ILIKE '%y'",
"hive": "LOWER(x) LIKE '%y'",
"mysql": "LOWER(x) LIKE '%y'",
"oracle": "LOWER(x) LIKE '%y'",
"postgres": "x ILIKE '%y'",
"presto": "LOWER(x) LIKE '%y'",
"snowflake": "x ILIKE '%y'",
"spark": "LOWER(x) LIKE '%y'",
"sqlite": "LOWER(x) LIKE '%y'",
"starrocks": "LOWER(x) LIKE '%y'",
"trino": "LOWER(x) LIKE '%y'",
},
)
self.validate_all(
"SELECT * FROM a ORDER BY col_a NULLS LAST",
write={
"mysql": UnsupportedError,
"starrocks": UnsupportedError,
},
)
self.validate_all(
"STR_POSITION(x, 'a')",
write={
"duckdb": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
},
)
self.validate_all(
"CONCAT_WS('-', 'a', 'b')",
write={
"duckdb": "CONCAT_WS('-', 'a', 'b')",
"presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')",
"hive": "CONCAT_WS('-', 'a', 'b')",
"spark": "CONCAT_WS('-', 'a', 'b')",
},
)
self.validate_all(
"CONCAT_WS('-', x)",
write={
"duckdb": "CONCAT_WS('-', x)",
"presto": "ARRAY_JOIN(x, '-')",
"hive": "CONCAT_WS('-', x)",
"spark": "CONCAT_WS('-', x)",
},
)
self.validate_all(
"IF(x > 1, 1, 0)",
write={
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
"presto": "IF(x > 1, 1, 0)",
"hive": "IF(x > 1, 1, 0)",
"spark": "IF(x > 1, 1, 0)",
"tableau": "IF x > 1 THEN 1 ELSE 0 END",
},
)
self.validate_all(
"CASE WHEN 1 THEN x ELSE 0 END",
write={
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
"presto": "CASE WHEN 1 THEN x ELSE 0 END",
"hive": "CASE WHEN 1 THEN x ELSE 0 END",
"spark": "CASE WHEN 1 THEN x ELSE 0 END",
"tableau": "CASE WHEN 1 THEN x ELSE 0 END",
},
)
self.validate_all(
"x[y]",
write={
"duckdb": "x[y]",
"presto": "x[y]",
"hive": "x[y]",
"spark": "x[y]",
},
)
self.validate_all(
"""'["x"]'""",
write={
"duckdb": """'["x"]'""",
"presto": """'["x"]'""",
"hive": """'["x"]'""",
"spark": """'["x"]'""",
},
)
self.validate_all(
'true or null as "foo"',
write={
"bigquery": "TRUE OR NULL AS `foo`",
"duckdb": 'TRUE OR NULL AS "foo"',
"presto": 'TRUE OR NULL AS "foo"',
"hive": "TRUE OR NULL AS `foo`",
"spark": "TRUE OR NULL AS `foo`",
},
)
self.validate_all(
"SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz",
write={
"bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz",
"duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz",
"presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
"hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
"spark": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz",
},
)
self.validate_all(
"LEVENSHTEIN(col1, col2)",
write={
"duckdb": "LEVENSHTEIN(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
"hive": "LEVENSHTEIN(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
},
)
self.validate_all(
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
},
)
self.validate_all(
"ARRAY_FILTER(the_array, x -> x > 0)",
write={
"presto": "FILTER(the_array, x -> x > 0)",
"hive": "FILTER(the_array, x -> x > 0)",
"spark": "FILTER(the_array, x -> x > 0)",
},
)
self.validate_all(
"SELECT a AS b FROM x GROUP BY b",
write={
"duckdb": "SELECT a AS b FROM x GROUP BY b",
"presto": "SELECT a AS b FROM x GROUP BY 1",
"hive": "SELECT a AS b FROM x GROUP BY 1",
"oracle": "SELECT a AS b FROM x GROUP BY 1",
"spark": "SELECT a AS b FROM x GROUP BY 1",
},
)
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
"sqlite": "SELECT x FROM y LIMIT 10",
"oracle": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY",
},
)
self.validate_all(
"SELECT x FROM y LIMIT 10 OFFSET 5",
write={
"sqlite": "SELECT x FROM y LIMIT 10 OFFSET 5",
"oracle": "SELECT x FROM y OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY",
},
)
self.validate_all(
"SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY",
write={
"oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
},
)
self.validate_all(
"SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
write={
"oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
},
)
self.validate_all(
'"x" + "y"',
read={
"clickhouse": '`x` + "y"',
"sqlite": '`x` + "y"',
},
)
self.validate_all(
"[1, 2]",
write={
"bigquery": "[1, 2]",
"clickhouse": "[1, 2]",
},
)
self.validate_all(
"SELECT * FROM VALUES ('x'), ('y') AS t(z)",
write={
"spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)",
},
)
self.validate_all(
"CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR, v2 VARCHAR2, nv NVARCHAR, nv2 NVARCHAR2)",
write={
"hive": "CREATE TABLE t (c CHAR, nc CHAR, v1 STRING, v2 STRING, nv STRING, nv2 STRING)",
"oracle": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR2, v2 VARCHAR2, nv NVARCHAR2, nv2 NVARCHAR2)",
"postgres": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR, v2 VARCHAR, nv VARCHAR, nv2 VARCHAR)",
"sqlite": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)",
},
)
self.validate_all(
"POWER(1.2, 3.4)",
read={
"hive": "pow(1.2, 3.4)",
"postgres": "power(1.2, 3.4)",
},
)
self.validate_all(
"CREATE INDEX my_idx ON tbl (a, b)",
read={
"hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
"sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
},
write={
"hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
"postgres": "CREATE INDEX my_idx ON tbl (a, b)",
"sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
},
)
self.validate_all(
"CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
read={
"hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
"sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
},
write={
"hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
"postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
"sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
},
)
self.validate_all(
"CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))",
write={
"hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))",
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
},
)

249
tests/dialects/test_duckdb.py Normal file
View file

@ -0,0 +1,249 @@
from tests.dialects.test_dialect import Validator
class TestDuckDB(Validator):
dialect = "duckdb"
def test_time(self):
self.validate_all(
"EPOCH(x)",
read={
"presto": "TO_UNIXTIME(x)",
},
write={
"bigquery": "TIME_TO_UNIX(x)",
"duckdb": "EPOCH(x)",
"presto": "TO_UNIXTIME(x)",
"spark": "UNIX_TIMESTAMP(x)",
},
)
self.validate_all(
"EPOCH_MS(x)",
write={
"bigquery": "UNIX_TO_TIME(x / 1000)",
"duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))",
"presto": "FROM_UNIXTIME(x / 1000)",
"spark": "FROM_UNIXTIME(x / 1000)",
},
)
self.validate_all(
"STRFTIME(x, '%y-%-m-%S')",
write={
"bigquery": "TIME_TO_STR(x, '%y-%-m-%S')",
"duckdb": "STRFTIME(x, '%y-%-m-%S')",
"postgres": "TO_CHAR(x, 'YY-FMMM-SS')",
"presto": "DATE_FORMAT(x, '%y-%c-%S')",
"spark": "DATE_FORMAT(x, 'yy-M-ss')",
},
)
self.validate_all(
"STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
write={
"duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(
"STRPTIME(x, '%y-%-m')",
write={
"bigquery": "STR_TO_TIME(x, '%y-%-m')",
"duckdb": "STRPTIME(x, '%y-%-m')",
"presto": "DATE_PARSE(x, '%y-%c')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yy-M')",
},
)
self.validate_all(
"TO_TIMESTAMP(x)",
write={
"duckdb": "CAST(x AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')",
"hive": "CAST(x AS TIMESTAMP)",
},
)
def test_duckdb(self):
self.validate_all(
"LIST_VALUE(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"REGEXP_MATCHES(x, y)",
write={
"duckdb": "REGEXP_MATCHES(x, y)",
"presto": "REGEXP_LIKE(x, y)",
"hive": "x RLIKE y",
"spark": "x RLIKE y",
},
)
self.validate_all(
"STR_SPLIT(x, 'a')",
write={
"duckdb": "STR_SPLIT(x, 'a')",
"presto": "SPLIT(x, 'a')",
"hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))",
"spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))",
},
)
self.validate_all(
"STRING_TO_ARRAY(x, 'a')",
write={
"duckdb": "STR_SPLIT(x, 'a')",
"presto": "SPLIT(x, 'a')",
"hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))",
"spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))",
},
)
self.validate_all(
"STR_SPLIT_REGEX(x, 'a')",
write={
"duckdb": "STR_SPLIT_REGEX(x, 'a')",
"presto": "REGEXP_SPLIT(x, 'a')",
"hive": "SPLIT(x, 'a')",
"spark": "SPLIT(x, 'a')",
},
)
self.validate_all(
"STRUCT_EXTRACT(x, 'abc')",
write={
"duckdb": "STRUCT_EXTRACT(x, 'abc')",
"presto": 'x."abc"',
"hive": "x.`abc`",
"spark": "x.`abc`",
},
)
self.validate_all(
"STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
write={
"duckdb": "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
"presto": 'x."y"."abc"',
"hive": "x.`y`.`abc`",
"spark": "x.`y`.`abc`",
},
)
self.validate_all(
"QUANTILE(x, 0.5)",
write={
"duckdb": "QUANTILE(x, 0.5)",
"presto": "APPROX_PERCENTILE(x, 0.5)",
"hive": "PERCENTILE(x, 0.5)",
"spark": "PERCENTILE(x, 0.5)",
},
)
self.validate_all(
"CAST(x AS DATE)",
write={
"duckdb": "CAST(x AS DATE)",
"": "CAST(x AS DATE)",
},
)
self.validate_all(
"UNNEST(x)",
read={
"spark": "EXPLODE(x)",
},
write={
"duckdb": "UNNEST(x)",
"spark": "EXPLODE(x)",
},
)
self.validate_all(
"1d",
write={
"duckdb": "1 AS d",
"spark": "1 AS d",
},
)
self.validate_all(
"CAST(1 AS DOUBLE)",
read={
"hive": "1d",
"spark": "1d",
},
)
self.validate_all(
"POWER(CAST(2 AS SMALLINT), 3)",
read={
"hive": "POW(2S, 3)",
"spark": "POW(2S, 3)",
},
)
self.validate_all(
"LIST_SUM(LIST_VALUE(1, 2))",
read={
"spark": "ARRAY_SUM(ARRAY(1, 2))",
},
)
self.validate_all(
"IF(y <> 0, x / y, NULL)",
read={
"bigquery": "SAFE_DIVIDE(x, y)",
},
)
self.validate_all(
"STRUCT_PACK(x := 1, y := '2')",
write={
"duckdb": "STRUCT_PACK(x := 1, y := '2')",
"spark": "STRUCT(x = 1, y = '2')",
},
)
self.validate_all(
"ARRAY_SORT(x)",
write={
"duckdb": "ARRAY_SORT(x)",
"presto": "ARRAY_SORT(x)",
"hive": "SORT_ARRAY(x)",
"spark": "SORT_ARRAY(x)",
},
)
self.validate_all(
"ARRAY_REVERSE_SORT(x)",
write={
"duckdb": "ARRAY_REVERSE_SORT(x)",
"presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)",
"hive": "SORT_ARRAY(x, FALSE)",
"spark": "SORT_ARRAY(x, FALSE)",
},
)
self.validate_all(
"LIST_REVERSE_SORT(x)",
write={
"duckdb": "ARRAY_REVERSE_SORT(x)",
"presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)",
"hive": "SORT_ARRAY(x, FALSE)",
"spark": "SORT_ARRAY(x, FALSE)",
},
)
self.validate_all(
"LIST_SORT(x)",
write={
"duckdb": "ARRAY_SORT(x)",
"presto": "ARRAY_SORT(x)",
"hive": "SORT_ARRAY(x)",
"spark": "SORT_ARRAY(x)",
},
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
self.validate_all(
"MONTH('2021-03-01')",
write={
"duckdb": "MONTH('2021-03-01')",
"presto": "MONTH('2021-03-01')",
"hive": "MONTH('2021-03-01')",
"spark": "MONTH('2021-03-01')",
},
)

541
tests/dialects/test_hive.py Normal file
View file

@ -0,0 +1,541 @@
from tests.dialects.test_dialect import Validator
class TestHive(Validator):
dialect = "hive"
def test_bits(self):
self.validate_all(
"x & 1",
write={
"duckdb": "x & 1",
"presto": "BITWISE_AND(x, 1)",
"hive": "x & 1",
"spark": "x & 1",
},
)
self.validate_all(
"~x",
write={
"duckdb": "~x",
"presto": "BITWISE_NOT(x)",
"hive": "~x",
"spark": "~x",
},
)
self.validate_all(
"x | 1",
write={
"duckdb": "x | 1",
"presto": "BITWISE_OR(x, 1)",
"hive": "x | 1",
"spark": "x | 1",
},
)
self.validate_all(
"x << 1",
read={
"spark": "SHIFTLEFT(x, 1)",
},
write={
"duckdb": "x << 1",
"presto": "BITWISE_ARITHMETIC_SHIFT_LEFT(x, 1)",
"hive": "x << 1",
"spark": "SHIFTLEFT(x, 1)",
},
)
self.validate_all(
"x >> 1",
read={
"spark": "SHIFTRIGHT(x, 1)",
},
write={
"duckdb": "x >> 1",
"presto": "BITWISE_ARITHMETIC_SHIFT_RIGHT(x, 1)",
"hive": "x >> 1",
"spark": "SHIFTRIGHT(x, 1)",
},
)
self.validate_all(
"x & 1 > 0",
write={
"duckdb": "x & 1 > 0",
"presto": "BITWISE_AND(x, 1) > 0",
"hive": "x & 1 > 0",
"spark": "x & 1 > 0",
},
)
def test_cast(self):
self.validate_all(
"1s",
write={
"duckdb": "CAST(1 AS SMALLINT)",
"presto": "CAST(1 AS SMALLINT)",
"hive": "CAST(1 AS SMALLINT)",
"spark": "CAST(1 AS SHORT)",
},
)
self.validate_all(
"1S",
write={
"duckdb": "CAST(1 AS SMALLINT)",
"presto": "CAST(1 AS SMALLINT)",
"hive": "CAST(1 AS SMALLINT)",
"spark": "CAST(1 AS SHORT)",
},
)
self.validate_all(
"1Y",
write={
"duckdb": "CAST(1 AS TINYINT)",
"presto": "CAST(1 AS TINYINT)",
"hive": "CAST(1 AS TINYINT)",
"spark": "CAST(1 AS BYTE)",
},
)
self.validate_all(
"1L",
write={
"duckdb": "CAST(1 AS BIGINT)",
"presto": "CAST(1 AS BIGINT)",
"hive": "CAST(1 AS BIGINT)",
"spark": "CAST(1 AS LONG)",
},
)
self.validate_all(
"1.0bd",
write={
"duckdb": "CAST(1.0 AS DECIMAL)",
"presto": "CAST(1.0 AS DECIMAL)",
"hive": "CAST(1.0 AS DECIMAL)",
"spark": "CAST(1.0 AS DECIMAL)",
},
)
self.validate_all(
"CAST(1 AS INT)",
read={
"presto": "TRY_CAST(1 AS INT)",
},
write={
"duckdb": "TRY_CAST(1 AS INT)",
"presto": "TRY_CAST(1 AS INTEGER)",
"hive": "CAST(1 AS INT)",
"spark": "CAST(1 AS INT)",
},
)
def test_ddl(self):
self.validate_all(
"CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
write={
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
},
)
def test_lateral_view(self):
self.validate_all(
"SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
write={
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)",
"hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
},
)
self.validate_all(
"SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
)
self.validate_all(
"SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a",
},
)
self.validate_all(
"SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
},
)
def test_quotes(self):
self.validate_all(
"'\\''",
write={
"duckdb": "''''",
"presto": "''''",
"hive": "'\\''",
"spark": "'\\''",
},
)
self.validate_all(
"'\"x\"'",
write={
"duckdb": "'\"x\"'",
"presto": "'\"x\"'",
"hive": "'\"x\"'",
"spark": "'\"x\"'",
},
)
self.validate_all(
"\"'x'\"",
write={
"duckdb": "'''x'''",
"presto": "'''x'''",
"hive": "'\\'x\\''",
"spark": "'\\'x\\''",
},
)
self.validate_all(
"'\\\\a'",
read={
"presto": "'\\a'",
},
write={
"duckdb": "'\\a'",
"presto": "'\\a'",
"hive": "'\\\\a'",
"spark": "'\\\\a'",
},
)
def test_regex(self):
self.validate_all(
"a RLIKE 'x'",
write={
"duckdb": "REGEXP_MATCHES(a, 'x')",
"presto": "REGEXP_LIKE(a, 'x')",
"hive": "a RLIKE 'x'",
"spark": "a RLIKE 'x'",
},
)
self.validate_all(
"a REGEXP 'x'",
write={
"duckdb": "REGEXP_MATCHES(a, 'x')",
"presto": "REGEXP_LIKE(a, 'x')",
"hive": "a RLIKE 'x'",
"spark": "a RLIKE 'x'",
},
)
def test_time(self):
self.validate_all(
"DATEDIFF(a, b)",
write={
"duckdb": "DATE_DIFF('day', CAST(b AS DATE), CAST(a AS DATE))",
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))",
"hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
"spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
"": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
},
)
self.validate_all(
"""from_unixtime(x, "yyyy-MM-dd'T'HH")""",
write={
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')",
"hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",
"spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",
},
)
self.validate_all(
"DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
write={
"duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(
"DATE_ADD('2020-01-01', 1)",
write={
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))",
"hive": "DATE_ADD('2020-01-01', 1)",
"spark": "DATE_ADD('2020-01-01', 1)",
"": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')",
},
)
self.validate_all(
"DATE_SUB('2020-01-01', 1)",
write={
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 * -1 DAY",
"presto": "DATE_ADD('DAY', 1 * -1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))",
"hive": "DATE_ADD('2020-01-01', 1 * -1)",
"spark": "DATE_ADD('2020-01-01', 1 * -1)",
"": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')",
},
)
self.validate_all(
"DATEDIFF(TO_DATE(y), x)",
write={
"duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))",
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))",
"hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
"spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
"": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
},
)
self.validate_all(
"UNIX_TIMESTAMP(x)",
write={
"duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))",
"presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))",
"hive": "UNIX_TIMESTAMP(x)",
"spark": "UNIX_TIMESTAMP(x)",
"": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
f"{unit}(x)",
write={
"duckdb": f"{unit}(CAST(x AS DATE))",
"presto": f"{unit}(CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE))",
"hive": f"{unit}(TO_DATE(x))",
"spark": f"{unit}(TO_DATE(x))",
},
)
def test_order_by(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
def test_hive(self):
self.validate_all(
"PERCENTILE(x, 0.5)",
write={
"duckdb": "QUANTILE(x, 0.5)",
"presto": "APPROX_PERCENTILE(x, 0.5)",
"hive": "PERCENTILE(x, 0.5)",
"spark": "PERCENTILE(x, 0.5)",
},
)
self.validate_all(
"APPROX_COUNT_DISTINCT(a)",
write={
"duckdb": "APPROX_COUNT_DISTINCT(a)",
"presto": "APPROX_DISTINCT(a)",
"hive": "APPROX_COUNT_DISTINCT(a)",
"spark": "APPROX_COUNT_DISTINCT(a)",
},
)
self.validate_all(
"ARRAY_CONTAINS(x, 1)",
write={
"duckdb": "ARRAY_CONTAINS(x, 1)",
"presto": "CONTAINS(x, 1)",
"hive": "ARRAY_CONTAINS(x, 1)",
"spark": "ARRAY_CONTAINS(x, 1)",
},
)
self.validate_all(
"SIZE(x)",
write={
"duckdb": "ARRAY_LENGTH(x)",
"presto": "CARDINALITY(x)",
"hive": "SIZE(x)",
"spark": "SIZE(x)",
},
)
self.validate_all(
"LOCATE('a', x)",
write={
"duckdb": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"hive": "LOCATE('a', x)",
"spark": "LOCATE('a', x)",
},
)
self.validate_all(
"LOCATE('a', x, 3)",
write={
"duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"hive": "LOCATE('a', x, 3)",
"spark": "LOCATE('a', x, 3)",
},
)
self.validate_all(
"INITCAP('new york')",
write={
"duckdb": "INITCAP('new york')",
"presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))",
"hive": "INITCAP('new york')",
"spark": "INITCAP('new york')",
},
)
self.validate_all(
"SELECT * FROM x TABLESAMPLE(10) y",
write={
"presto": "SELECT * FROM x AS y TABLESAMPLE(10)",
"hive": "SELECT * FROM x TABLESAMPLE(10) AS y",
"spark": "SELECT * FROM x TABLESAMPLE(10) AS y",
},
)
self.validate_all(
"SELECT SORT_ARRAY(x)",
write={
"duckdb": "SELECT ARRAY_SORT(x)",
"presto": "SELECT ARRAY_SORT(x)",
"hive": "SELECT SORT_ARRAY(x)",
"spark": "SELECT SORT_ARRAY(x)",
},
)
self.validate_all(
"SELECT SORT_ARRAY(x, FALSE)",
read={
"duckdb": "SELECT ARRAY_REVERSE_SORT(x)",
"spark": "SELECT SORT_ARRAY(x, FALSE)",
},
write={
"duckdb": "SELECT ARRAY_REVERSE_SORT(x)",
"presto": "SELECT ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)",
"hive": "SELECT SORT_ARRAY(x, FALSE)",
"spark": "SELECT SORT_ARRAY(x, FALSE)",
},
)
self.validate_all(
"GET_JSON_OBJECT(x, '$.name')",
write={
"presto": "JSON_EXTRACT_SCALAR(x, '$.name')",
"hive": "GET_JSON_OBJECT(x, '$.name')",
"spark": "GET_JSON_OBJECT(x, '$.name')",
},
)
self.validate_all(
"MAP(a, b, c, d)",
write={
"duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)",
"spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))",
},
)
self.validate_all(
"MAP(a, b)",
write={
"duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))",
"presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)",
"spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))",
},
)
self.validate_all(
"LOG(10)",
write={
"duckdb": "LN(10)",
"presto": "LN(10)",
"hive": "LN(10)",
"spark": "LN(10)",
},
)
self.validate_all(
"LOG(10, 2)",
write={
"duckdb": "LOG(10, 2)",
"presto": "LOG(10, 2)",
"hive": "LOG(10, 2)",
"spark": "LOG(10, 2)",
},
)
self.validate_all(
'ds = "2020-01-01"',
write={
"duckdb": "ds = '2020-01-01'",
"presto": "ds = '2020-01-01'",
"hive": "ds = '2020-01-01'",
"spark": "ds = '2020-01-01'",
},
)
self.validate_all(
"ds = \"1''2\"",
write={
"duckdb": "ds = '1''''2'",
"presto": "ds = '1''''2'",
"hive": "ds = '1\\'\\'2'",
"spark": "ds = '1\\'\\'2'",
},
)
self.validate_all(
"x == 1",
write={
"duckdb": "x = 1",
"presto": "x = 1",
"hive": "x = 1",
"spark": "x = 1",
},
)
self.validate_all(
"x div y",
write={
"duckdb": "CAST(x / y AS INT)",
"presto": "CAST(x / y AS INTEGER)",
"hive": "CAST(x / y AS INT)",
"spark": "CAST(x / y AS INT)",
},
)
self.validate_all(
"COLLECT_LIST(x)",
read={
"presto": "ARRAY_AGG(x)",
},
write={
"duckdb": "ARRAY_AGG(x)",
"presto": "ARRAY_AGG(x)",
"hive": "COLLECT_LIST(x)",
"spark": "COLLECT_LIST(x)",
},
)
self.validate_all(
"COLLECT_SET(x)",
read={
"presto": "SET_AGG(x)",
},
write={
"presto": "SET_AGG(x)",
"hive": "COLLECT_SET(x)",
"spark": "COLLECT_SET(x)",
},
)
self.validate_all(
"SELECT * FROM x TABLESAMPLE(1) AS foo",
read={
"presto": "SELECT * FROM x AS foo TABLESAMPLE(1)",
},
write={
"presto": "SELECT * FROM x AS foo TABLESAMPLE(1)",
"hive": "SELECT * FROM x TABLESAMPLE(1) AS foo",
"spark": "SELECT * FROM x TABLESAMPLE(1) AS foo",
},
)

79
tests/dialects/test_mysql.py Normal file
View file

@ -0,0 +1,79 @@
from tests.dialects.test_dialect import Validator
class TestMySQL(Validator):
dialect = "mysql"
def test_ddl(self):
self.validate_all(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
write={
"mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
"spark": "CREATE TABLE z (a INT) COMMENT 'x'",
},
)
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
def test_introducers(self):
self.validate_all(
"_utf8mb4 'hola'",
read={
"mysql": "_utf8mb4'hola'",
},
write={
"mysql": "_utf8mb4 'hola'",
},
)
def test_binary_literal(self):
self.validate_all(
"SELECT 0xCC",
write={
"mysql": "SELECT b'11001100'",
"spark": "SELECT X'11001100'",
},
)
self.validate_all(
"SELECT 0xz",
write={
"mysql": "SELECT `0xz`",
},
)
self.validate_all(
"SELECT 0XCC",
write={
"mysql": "SELECT 0 AS XCC",
},
)
def test_string_literals(self):
self.validate_all(
'SELECT "2021-01-01" + INTERVAL 1 MONTH',
write={
"mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
},
)
def test_convert(self):
self.validate_all(
"CONVERT(x USING latin1)",
write={
"mysql": "CAST(x AS CHAR CHARACTER SET latin1)",
},
)
self.validate_all(
"CAST(x AS CHAR CHARACTER SET latin1)",
write={
"mysql": "CAST(x AS CHAR CHARACTER SET latin1)",
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 # arbitrary content,,, until end-of-line",
write={
"mysql": "SELECT 1",
},
)

93
tests/dialects/test_postgres.py Normal file
View file

@ -0,0 +1,93 @@
from sqlglot import ParseError, transpile
from tests.dialects.test_dialect import Validator
class TestPostgres(Validator):
dialect = "postgres"
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={
"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
write={
"postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))",
write={
"postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))"
},
)
self.validate_all(
"CREATE TABLE products ("
"product_no INT UNIQUE,"
" name TEXT,"
" price DECIMAL CHECK (price > 0),"
" discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0),"
" CHECK (product_no > 1),"
" CONSTRAINT valid_discount CHECK (price > discounted_price))",
write={
"postgres": "CREATE TABLE products ("
"product_no INT UNIQUE,"
" name TEXT,"
" price DECIMAL CHECK (price > 0),"
" discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0),"
" CHECK (product_no > 1),"
" CONSTRAINT valid_discount CHECK (price > discounted_price))"
},
)
with self.assertRaises(ParseError):
transpile(
"CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres"
)
with self.assertRaises(ParseError):
transpile(
"CREATE TABLE products (price DECIMAL, CHECK price > 1)",
read="postgres",
)
def test_postgres(self):
self.validate_all(
"CREATE TABLE x (a INT SERIAL)",
read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"},
)
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
},
)
self.validate_all(
"SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)",
write={
"postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)",
},
)
self.validate_all(
"SELECT * FROM x FETCH 1 ROW",
write={
"postgres": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY",
"presto": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY",
"hive": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY",
"spark": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY",
},
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)

422
tests/dialects/test_presto.py Normal file
View file

@ -0,0 +1,422 @@
from sqlglot import UnsupportedError
from tests.dialects.test_dialect import Validator
class TestPresto(Validator):
dialect = "presto"
def test_cast(self):
self.validate_all(
"CAST(a AS ARRAY(INT))",
write={
"bigquery": "CAST(a AS ARRAY<INT64>)",
"duckdb": "CAST(a AS ARRAY<INT>)",
"presto": "CAST(a AS ARRAY(INTEGER))",
"spark": "CAST(a AS ARRAY<INT>)",
},
)
self.validate_all(
"CAST(a AS VARCHAR)",
write={
"bigquery": "CAST(a AS STRING)",
"duckdb": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
"spark": "CAST(a AS STRING)",
},
)
self.validate_all(
"CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
write={
"bigquery": "CAST([1, 2] AS ARRAY<INT64>)",
"duckdb": "CAST(LIST_VALUE(1, 2) AS ARRAY<BIGINT>)",
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
"spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)",
},
)
self.validate_all(
"CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INT,INT))",
write={
"bigquery": "CAST(MAP([1], [1]) AS MAP<INT64, INT64>)",
"duckdb": "CAST(MAP(LIST_VALUE(1), LIST_VALUE(1)) AS MAP<INT, INT>)",
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
},
)
self.validate_all(
"CAST(MAP(ARRAY['a','b','c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INT)))",
write={
"bigquery": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP<STRING, ARRAY<INT64>>)",
"duckdb": "CAST(MAP(LIST_VALUE('a', 'b', 'c'), LIST_VALUE(LIST_VALUE(1), LIST_VALUE(2), LIST_VALUE(3))) AS MAP<TEXT, ARRAY<INT>>)",
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
},
)
self.validate_all(
"CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
write={
"bigquery": "CAST(x AS TIMESTAMPTZ(9))",
"duckdb": "CAST(x AS TIMESTAMPTZ(9))",
"presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
"hive": "CAST(x AS TIMESTAMPTZ(9))",
"spark": "CAST(x AS TIMESTAMPTZ(9))",
},
)
def test_regex(self):
self.validate_all(
"REGEXP_LIKE(a, 'x')",
write={
"duckdb": "REGEXP_MATCHES(a, 'x')",
"presto": "REGEXP_LIKE(a, 'x')",
"hive": "a RLIKE 'x'",
"spark": "a RLIKE 'x'",
},
)
self.validate_all(
"SPLIT(x, 'a.')",
write={
"duckdb": "STR_SPLIT(x, 'a.')",
"presto": "SPLIT(x, 'a.')",
"hive": "SPLIT(x, CONCAT('\\\\Q', 'a.'))",
"spark": "SPLIT(x, CONCAT('\\\\Q', 'a.'))",
},
)
self.validate_all(
"REGEXP_SPLIT(x, 'a.')",
write={
"duckdb": "STR_SPLIT_REGEX(x, 'a.')",
"presto": "REGEXP_SPLIT(x, 'a.')",
"hive": "SPLIT(x, 'a.')",
"spark": "SPLIT(x, 'a.')",
},
)
self.validate_all(
"CARDINALITY(x)",
write={
"duckdb": "ARRAY_LENGTH(x)",
"presto": "CARDINALITY(x)",
"hive": "SIZE(x)",
"spark": "SIZE(x)",
},
)
def test_time(self):
self.validate_all(
"DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
write={
"duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(
"DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')",
write={
"duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(
"DATE_PARSE(x, '%Y-%m-%d')",
write={
"duckdb": "STRPTIME(x, '%Y-%m-%d')",
"presto": "DATE_PARSE(x, '%Y-%m-%d')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')",
},
)
self.validate_all(
"DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
write={
"duckdb": "STRPTIME(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"hive": "CAST(SUBSTR(x, 1, 10) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(SUBSTR(x, 1, 10), 'yyyy-MM-dd')",
},
)
self.validate_all(
"FROM_UNIXTIME(x)",
write={
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
"presto": "FROM_UNIXTIME(x)",
"hive": "FROM_UNIXTIME(x)",
"spark": "FROM_UNIXTIME(x)",
},
)
self.validate_all(
"TO_UNIXTIME(x)",
write={
"duckdb": "EPOCH(x)",
"presto": "TO_UNIXTIME(x)",
"hive": "UNIX_TIMESTAMP(x)",
"spark": "UNIX_TIMESTAMP(x)",
},
)
self.validate_all(
"DATE_ADD('day', 1, x)",
write={
"duckdb": "x + INTERVAL 1 day",
"presto": "DATE_ADD('day', 1, x)",
"hive": "DATE_ADD(x, 1)",
"spark": "DATE_ADD(x, 1)",
},
)
def test_ddl(self):
self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
write={
"presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])",
"hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
"spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)",
},
)
self.validate_all(
"CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y",
write={
"presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y",
"hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y",
"spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y",
},
)
self.validate_all(
"CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))",
write={
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)",
},
)
self.validate_all(
"CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))",
write={
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)",
},
)
self.validate(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
read="presto",
write="presto",
)
def test_quotes(self):
self.validate_all(
"''''",
write={
"duckdb": "''''",
"presto": "''''",
"hive": "'\\''",
"spark": "'\\''",
},
)
self.validate_all(
"'x'",
write={
"duckdb": "'x'",
"presto": "'x'",
"hive": "'x'",
"spark": "'x'",
},
)
self.validate_all(
"'''x'''",
write={
"duckdb": "'''x'''",
"presto": "'''x'''",
"hive": "'\\'x\\''",
"spark": "'\\'x\\''",
},
)
self.validate_all(
"'''x'",
write={
"duckdb": "'''x'",
"presto": "'''x'",
"hive": "'\\'x'",
"spark": "'\\'x'",
},
)
self.validate_all(
"x IN ('a', 'a''b')",
write={
"duckdb": "x IN ('a', 'a''b')",
"presto": "x IN ('a', 'a''b')",
"hive": "x IN ('a', 'a\\'b')",
"spark": "x IN ('a', 'a\\'b')",
},
)
def test_unnest(self):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y)) AS t (a)",
write={
"presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)",
"hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
},
)
def test_presto(self):
self.validate_all(
'SELECT a."b" FROM "foo"',
write={
"duckdb": 'SELECT a."b" FROM "foo"',
"presto": 'SELECT a."b" FROM "foo"',
"spark": "SELECT a.`b` FROM `foo`",
},
)
self.validate_all(
"SELECT ARRAY[1, 2]",
write={
"bigquery": "SELECT [1, 2]",
"duckdb": "SELECT LIST_VALUE(1, 2)",
"presto": "SELECT ARRAY[1, 2]",
"spark": "SELECT ARRAY(1, 2)",
},
)
self.validate_all(
"SELECT APPROX_DISTINCT(a) FROM foo",
write={
"duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"presto": "SELECT APPROX_DISTINCT(a) FROM foo",
"hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
},
)
self.validate_all(
"SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
write={
"duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
},
)
self.validate_all(
"SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
write={
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": UnsupportedError,
"spark": UnsupportedError,
},
)
self.validate_all(
"SELECT JSON_EXTRACT(x, '$.name')",
write={
"presto": "SELECT JSON_EXTRACT(x, '$.name')",
"hive": "SELECT GET_JSON_OBJECT(x, '$.name')",
"spark": "SELECT GET_JSON_OBJECT(x, '$.name')",
},
)
self.validate_all(
"SELECT JSON_EXTRACT_SCALAR(x, '$.name')",
write={
"presto": "SELECT JSON_EXTRACT_SCALAR(x, '$.name')",
"hive": "SELECT GET_JSON_OBJECT(x, '$.name')",
"spark": "SELECT GET_JSON_OBJECT(x, '$.name')",
},
)
self.validate_all(
"'\u6bdb'",
write={
"presto": "'\u6bdb'",
"hive": "'\u6bdb'",
"spark": "'\u6bdb'",
},
)
self.validate_all(
"SELECT ARRAY_SORT(x, (left, right) -> -1)",
write={
"duckdb": "SELECT ARRAY_SORT(x)",
"presto": "SELECT ARRAY_SORT(x, (left, right) -> -1)",
"hive": "SELECT SORT_ARRAY(x)",
"spark": "SELECT ARRAY_SORT(x, (left, right) -> -1)",
},
)
self.validate_all(
"SELECT ARRAY_SORT(x)",
write={
"presto": "SELECT ARRAY_SORT(x)",
"hive": "SELECT SORT_ARRAY(x)",
"spark": "SELECT ARRAY_SORT(x)",
},
)
self.validate_all(
"SELECT ARRAY_SORT(x, (left, right) -> -1)",
write={
"hive": UnsupportedError,
},
)
self.validate_all(
"MAP(a, b)",
write={
"hive": UnsupportedError,
"spark": "MAP_FROM_ARRAYS(a, b)",
},
)
self.validate_all(
"MAP(ARRAY(a, b), ARRAY(c, d))",
write={
"hive": "MAP(a, c, b, d)",
"presto": "MAP(ARRAY[a, b], ARRAY[c, d])",
"spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))",
},
)
self.validate_all(
"MAP(ARRAY('a'), ARRAY('b'))",
write={
"hive": "MAP('a', 'b')",
"presto": "MAP(ARRAY['a'], ARRAY['b'])",
"spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))",
},
)
self.validate_all(
"SELECT * FROM UNNEST(ARRAY['7', '14']) AS x",
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14'])",
"presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x",
"hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x",
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x",
},
)
self.validate_all(
"SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)",
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14']) AS y",
"presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)",
"hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)",
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)",
},
)
self.validate_all(
"WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n+1 FROM t WHERE n < 100 ) SELECT sum(n) FROM t",
write={
"presto": "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t",
"spark": UnsupportedError,
},
)

145
tests/dialects/test_snowflake.py Normal file
View file

@ -0,0 +1,145 @@
from sqlglot import UnsupportedError
from tests.dialects.test_dialect import Validator
class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
self.validate_all(
'x:a:"b c"',
write={
"duckdb": "x['a']['b c']",
"hive": "x['a']['b c']",
"presto": "x['a']['b c']",
"snowflake": "x['a']['b c']",
"spark": "x['a']['b c']",
},
)
self.validate_all(
"SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
write={
"bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS LAST LIMIT 10",
"snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
},
)
self.validate_all(
"SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1",
write={
"bigquery": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z NULLS LAST) = 1",
"snowflake": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP(1659981729)",
write={
"bigquery": "SELECT UNIX_TO_TIME(1659981729)",
"snowflake": "SELECT TO_TIMESTAMP(1659981729)",
"spark": "SELECT FROM_UNIXTIME(1659981729)",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP(1659981729000, 3)",
write={
"bigquery": "SELECT UNIX_TO_TIME(1659981729000, 'millis')",
"snowflake": "SELECT TO_TIMESTAMP(1659981729000, 3)",
"spark": "SELECT TIMESTAMP_MILLIS(1659981729000)",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP('1659981729')",
write={
"bigquery": "SELECT UNIX_TO_TIME('1659981729')",
"snowflake": "SELECT TO_TIMESTAMP('1659981729')",
"spark": "SELECT FROM_UNIXTIME('1659981729')",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP(1659981729000000000, 9)",
write={
"bigquery": "SELECT UNIX_TO_TIME(1659981729000000000, 'micros')",
"snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)",
"spark": "SELECT TIMESTAMP_MICROS(1659981729000000000)",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
write={
"bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')",
"snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')",
},
)
self.validate_all(
"SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
read={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
},
write={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
},
)
self.validate_all(
"SELECT IFF(TRUE, 'true', 'false')",
write={
"snowflake": "SELECT IFF(TRUE, 'true', 'false')",
},
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
},
)
self.validate_all(
"SELECT ARRAY_AGG(DISTINCT a)",
write={
"spark": "SELECT COLLECT_LIST(DISTINCT a)",
"snowflake": "SELECT ARRAY_AGG(DISTINCT a)",
},
)
self.validate_all(
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write={
"snowflake": UnsupportedError,
},
)
self.validate_all(
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
write={
"snowflake": UnsupportedError,
},
)
self.validate_all(
"SELECT ARRAY_UNION_AGG(a)",
write={
"snowflake": "SELECT ARRAY_UNION_AGG(a)",
},
)
self.validate_all(
"SELECT NVL2(a, b, c)",
write={
"snowflake": "SELECT NVL2(a, b, c)",
},
)
self.validate_all(
"SELECT $$a$$",
write={
"snowflake": "SELECT 'a'",
},
)
self.validate_all(
r"SELECT $$a ' \ \t \x21 z $ $$",
write={
"snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '",
},
)
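
The Snowflake cases above exercise the same public API; a minimal sketch, assuming only sqlglot.transpile:

import sqlglot

# An epoch-seconds TO_TIMESTAMP should become FROM_UNIXTIME in Spark,
# per the TO_TIMESTAMP case above.
print(sqlglot.transpile("SELECT TO_TIMESTAMP(1659981729)", read="snowflake", write="spark")[0])
# SELECT FROM_UNIXTIME(1659981729)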

View file

@ -0,0 +1,226 @@
from tests.dialects.test_dialect import Validator
class TestSpark(Validator):
dialect = "spark"
def test_ddl(self):
self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)",
write={
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)",
},
)
self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:struct<nested_col_a:string, nested_col_b:string>>)",
write={
"bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
"presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))",
"hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)",
"spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)",
},
)
self.validate_all(
"CREATE TABLE db.example_table (col_a array<int>, col_b array<array<int>>)",
write={
"bigquery": "CREATE TABLE db.example_table (col_a ARRAY<INT64>, col_b ARRAY<ARRAY<INT64>>)",
"presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))",
"hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
"spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)",
},
)
self.validate_all(
"CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
write={
"presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])",
"hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
"spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'",
},
)
self.validate_all(
"CREATE TABLE test STORED AS PARQUET AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
},
)
self.validate_all(
"CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
write={
"presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1",
"hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
"spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1",
},
)
self.validate_all(
"""CREATE TABLE blah (col_a INT) COMMENT "Test comment: blah" PARTITIONED BY (date STRING) STORED AS ICEBERG TBLPROPERTIES('x' = '1')""",
write={
"presto": """CREATE TABLE blah (
col_a INTEGER,
date VARCHAR
)
COMMENT='Test comment: blah'
WITH (
PARTITIONED_BY = ARRAY['date'],
FORMAT = 'ICEBERG',
x = '1'
)""",
"hive": """CREATE TABLE blah (
col_a INT
)
COMMENT 'Test comment: blah'
PARTITIONED BY (
date STRING
)
STORED AS ICEBERG
TBLPROPERTIES (
'x' = '1'
)""",
"spark": """CREATE TABLE blah (
col_a INT
)
COMMENT 'Test comment: blah'
PARTITIONED BY (
date STRING
)
STORED AS ICEBERG
TBLPROPERTIES (
'x' = '1'
)""",
},
pretty=True,
)
def test_to_date(self):
self.validate_all(
"TO_DATE(x, 'yyyy-MM-dd')",
write={
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)",
"spark": "TO_DATE(x)",
},
)
self.validate_all(
"TO_DATE(x, 'yyyy')",
write={
"duckdb": "CAST(STRPTIME(x, '%Y') AS DATE)",
"hive": "TO_DATE(x, 'yyyy')",
"presto": "CAST(DATE_PARSE(x, '%Y') AS DATE)",
"spark": "TO_DATE(x, 'yyyy')",
},
)
def test_hint(self):
self.validate_all(
"SELECT /*+ COALESCE(3) */ * FROM x",
write={
"spark": "SELECT /*+ COALESCE(3) */ * FROM x",
},
)
self.validate_all(
"SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
write={
"spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
},
)
def test_spark(self):
self.validate_all(
"ARRAY_SORT(x, (left, right) -> -1)",
write={
"duckdb": "ARRAY_SORT(x)",
"presto": "ARRAY_SORT(x, (left, right) -> -1)",
"hive": "SORT_ARRAY(x)",
"spark": "ARRAY_SORT(x, (left, right) -> -1)",
},
)
self.validate_all(
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
"duckdb": "LIST_VALUE(0, 1, 2)",
"presto": "ARRAY[0, 1, 2]",
"hive": "ARRAY(0, 1, 2)",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST",
},
)
self.validate_all(
"SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
write={
"duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"presto": "SELECT APPROX_DISTINCT(a) FROM foo",
"hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
},
)
self.validate_all(
"MONTH('2021-03-01')",
write={
"duckdb": "MONTH(CAST('2021-03-01' AS DATE))",
"presto": "MONTH(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))",
"hive": "MONTH(TO_DATE('2021-03-01'))",
"spark": "MONTH(TO_DATE('2021-03-01'))",
},
)
self.validate_all(
"YEAR('2021-03-01')",
write={
"duckdb": "YEAR(CAST('2021-03-01' AS DATE))",
"presto": "YEAR(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))",
"hive": "YEAR(TO_DATE('2021-03-01'))",
"spark": "YEAR(TO_DATE('2021-03-01'))",
},
)
self.validate_all(
"'\u6bdb'",
write={
"duckdb": "''",
"presto": "''",
"hive": "''",
"spark": "''",
},
)
self.validate_all(
"SELECT LEFT(x, 2), RIGHT(x, 2)",
write={
"duckdb": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"spark": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
},
)
self.validate_all(
"MAP_FROM_ARRAYS(ARRAY(1), c)",
write={
"duckdb": "MAP(LIST_VALUE(1), c)",
"presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
},
)
self.validate_all(
"SELECT ARRAY_SORT(x)",
write={
"duckdb": "SELECT ARRAY_SORT(x)",
"presto": "SELECT ARRAY_SORT(x)",
"hive": "SELECT SORT_ARRAY(x)",
"spark": "SELECT ARRAY_SORT(x)",
},
)
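
The DDL cases in this file translate Spark struct column types into other dialects; a minimal sketch of the same conversion, assuming only the public transpile API:

import sqlglot

# Spark's struct<...> column type should be rendered as Presto's ROW(...),
# per the first CREATE TABLE case above.
sql = "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)"
print(sqlglot.transpile(sql, read="spark", write="presto")[0])
# CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))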

View file

@ -0,0 +1,72 @@
from tests.dialects.test_dialect import Validator
class TestSQLite(Validator):
dialect = "sqlite"
def test_ddl(self):
self.validate_all(
"""
CREATE TABLE "Track"
(
CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"),
FOREIGN KEY ("AlbumId") REFERENCES "Album" ("AlbumId")
ON DELETE NO ACTION ON UPDATE NO ACTION,
FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT,
FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT
)
""",
write={
"sqlite": """CREATE TABLE "Track" (
CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"),
FOREIGN KEY ("AlbumId") REFERENCES "Album"("AlbumId") ON DELETE NO ACTION ON UPDATE NO ACTION,
FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT,
FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT
)""",
},
pretty=True,
)
self.validate_all(
"CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
read={
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
},
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
},
)
self.validate_all(
"""CREATE TABLE "x" ("Name" NVARCHAR(200) NOT NULL)""",
write={
"sqlite": """CREATE TABLE "x" ("Name" TEXT(200) NOT NULL)""",
"mysql": "CREATE TABLE `x` (`Name` VARCHAR(200) NOT NULL)",
},
)
def test_sqlite(self):
self.validate_all(
"SELECT CAST([a].[b] AS SMALLINT) FROM foo",
write={
"sqlite": 'SELECT CAST("a"."b" AS INTEGER) FROM foo',
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
},
)
self.validate_all(
"EDITDIST3(col1, col2)",
read={
"sqlite": "EDITDIST3(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
},
write={
"sqlite": "EDITDIST3(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
},
)
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
},
)
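
The EDITDIST3 case above is bidirectional; a minimal sketch of one direction, assuming only sqlglot.transpile:

import sqlglot

# Spark's LEVENSHTEIN should map to SQLite's EDITDIST3, per the case above.
print(sqlglot.transpile("LEVENSHTEIN(col1, col2)", read="spark", write="sqlite")[0])
# EDITDIST3(col1, col2)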

View file

@ -0,0 +1,8 @@
from tests.dialects.test_dialect import Validator
class TestStarrocks(Validator):
dialect = "starrocks"
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")

View file

@ -0,0 +1,62 @@
from tests.dialects.test_dialect import Validator
class TestTableau(Validator):
dialect = "tableau"
def test_tableau(self):
self.validate_all(
"IF x = 'a' THEN y ELSE NULL END",
read={
"presto": "IF(x = 'a', y, NULL)",
},
write={
"presto": "IF(x = 'a', y, NULL)",
"hive": "IF(x = 'a', y, NULL)",
"tableau": "IF x = 'a' THEN y ELSE NULL END",
},
)
self.validate_all(
"IFNULL(a, 0)",
read={
"presto": "COALESCE(a, 0)",
},
write={
"presto": "COALESCE(a, 0)",
"hive": "COALESCE(a, 0)",
"tableau": "IFNULL(a, 0)",
},
)
self.validate_all(
"COUNTD(a)",
read={
"presto": "COUNT(DISTINCT a)",
},
write={
"presto": "COUNT(DISTINCT a)",
"hive": "COUNT(DISTINCT a)",
"tableau": "COUNTD(a)",
},
)
self.validate_all(
"COUNTD((a))",
read={
"presto": "COUNT(DISTINCT(a))",
},
write={
"presto": "COUNT(DISTINCT (a))",
"hive": "COUNT(DISTINCT (a))",
"tableau": "COUNTD((a))",
},
)
self.validate_all(
"COUNT(a)",
read={
"presto": "COUNT(a)",
},
write={
"presto": "COUNT(a)",
"hive": "COUNT(a)",
"tableau": "COUNT(a)",
},
)
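
Tableau has no IF(...) function, so the conditional is rewritten as an IF ... THEN ... ELSE ... END expression; a minimal sketch using the public transpile API:

import sqlglot

# Per the first case above, Presto's IF() becomes Tableau's IF/THEN/ELSE/END.
print(sqlglot.transpile("IF(x = 'a', y, NULL)", read="presto", write="tableau")[0])
# IF x = 'a' THEN y ELSE NULL END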

514
tests/fixtures/identity.sql vendored Normal file
View file

@ -0,0 +1,514 @@
SUM(1)
SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y
1
1.0
1E2
1E+2
1E-2
1.1E10
1.12e-10
-11.023E7 * 3
(1 * 2) / (3 - 5)
((TRUE))
''
''''
'x'
'\x'
"x"
""
x
x % 1
x < 1
x <= 1
x > 1
x >= 1
x <> 1
x = y OR x > 1
x & 1
x | 1
x ^ 1
~x
x << 1
x >> 1
x >> 1 | 1 & 1 ^ 1
x || y
1 - -1
dec.x + y
a.filter
a.b.c
a.b.c.d
a.b.c.d.e
a.b.c.d.e[0]
a.b.c.d.e[0].f
a[0][0].b.c[1].d.e.f[1][1]
a[0].b[1]
a[0].b.c['d']
a.b.C()
a['x'].b.C()
a.B()
a['x'].C()
int.x
map.x
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
x BETWEEN -1 AND 1
x BETWEEN 'a' || b AND 'c' || d
NOT x IS NULL
x IS TRUE
x IS FALSE
time
zone
ARRAY<TEXT>
CURRENT_DATE
CURRENT_DATE('UTC')
CURRENT_DATE AT TIME ZONE 'UTC'
CURRENT_DATE AT TIME ZONE zone_column
CURRENT_DATE AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Tokio'
ARRAY()
ARRAY(1, 2)
ARRAY_CONTAINS(x, 1)
EXTRACT(x FROM y)
EXTRACT(DATE FROM y)
CONCAT_WS('-', 'a', 'b')
CONCAT_WS('-', 'a', 'b', 'c')
POSEXPLODE("x") AS ("a", "b")
POSEXPLODE("x") AS ("a", "b", "c")
STR_POSITION(x, 'a')
STR_POSITION(x, 'a', 3)
SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)]
x[ORDINAL(1)][SAFE_OFFSET(2)]
x LIKE SUBSTR('abc', 1, 1)
x LIKE y
x LIKE a.y
x LIKE '%y%'
x ILIKE '%y%'
x LIKE '%y%' ESCAPE '\'
x ILIKE '%y%' ESCAPE '\'
1 AS escape
INTERVAL '1' day
INTERVAL '1' month
INTERVAL '1 day'
INTERVAL 2 months
INTERVAL 1 + 3 days
TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY)
DATETIME_DIFF(CURRENT_DATE, 1, DAY)
QUANTILE(x, 0.5)
REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))
REGEXP_LIKE('new york', '.')
REGEXP_SPLIT('new york', '.')
SPLIT('new york', '.')
X((y AS z)).1
(x AS y, y AS z)
REPLACE(1)
DATE(x) = DATE(y)
TIMESTAMP(DATE(x))
TIMESTAMP_TRUNC(COALESCE(time_field, CURRENT_TIMESTAMP()), DAY)
COUNT(DISTINCT CASE WHEN DATE_TRUNC(DATE(time_field), isoweek) = DATE_TRUNC(DATE(time_field2), isoweek) THEN report_id ELSE NULL END)
x[y - 1]
CASE WHEN SUM(x) > 3 THEN 1 END OVER (PARTITION BY x)
SUM(ROW() OVER (PARTITION BY x))
SUM(ROW() OVER (PARTITION BY x + 1))
SUM(ROW() OVER (PARTITION BY x AND y))
(ROW() OVER ())
CASE WHEN (x > 1) THEN 1 ELSE 0 END
CASE (1) WHEN 1 THEN 1 ELSE 0 END
CASE 1 WHEN 1 THEN 1 ELSE 0 END
x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
ADD JAR s3://bucket
ADD JARS s3://bucket, c
ADD FILE s3://file
ADD FILES s3://file, s3://a
ADD ARCHIVE s3://file
ADD ARCHIVES s3://file, s3://a
BEGIN IMMEDIATE TRANSACTION
COMMIT
USE db
NOT 1
NOT NOT 1
SELECT * FROM test
SELECT *, 1 FROM test
SELECT * FROM a.b
SELECT * FROM a.b.c
SELECT * FROM table
SELECT 1
SELECT 1 FROM test
SELECT * FROM a, b, (SELECT 1) AS c
SELECT a FROM test
SELECT 1 AS filter
SELECT SUM(x) AS filter
SELECT 1 AS range FROM test
SELECT 1 AS count FROM test
SELECT 1 AS comment FROM test
SELECT 1 AS numeric FROM test
SELECT 1 AS number FROM test
SELECT t.count
SELECT DISTINCT x FROM test
SELECT DISTINCT x, y FROM test
SELECT DISTINCT TIMESTAMP_TRUNC(time_field, MONTH) AS time_value FROM "table"
SELECT DISTINCT ON (x) x, y FROM z
SELECT DISTINCT ON (x, y + 1) * FROM z
SELECT DISTINCT ON (x.y) * FROM z
SELECT top.x
SELECT TIMESTAMP(DATE_TRUNC(DATE(time_field), MONTH)) AS time_value FROM "table"
SELECT GREATEST((3 + 1), LEAST(3, 4))
SELECT TRANSFORM(a, b -> b) AS x
SELECT AGGREGATE(a, (a, b) -> a + b) AS x
SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x
SELECT LAG(x) OVER (ORDER BY y) AS x
SELECT LEAD(a) OVER (ORDER BY b) AS a
SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> "a" + ("z" - 1))
SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test
SELECT a AS b FROM test
SELECT "a"."b" FROM "a"
SELECT "a".b FROM a
SELECT a.b FROM "a"
SELECT a.b FROM a
SELECT '"hi' AS x FROM x
SELECT 1 AS "|sum" FROM x
SELECT '\"hi' AS x FROM x
SELECT 1 AS b FROM test
SELECT 1 AS "b" FROM test
SELECT 1 + 1 FROM test
SELECT 1 - 1 FROM test
SELECT 1 * 1 FROM test
SELECT 1 % 1 FROM test
SELECT 1 / 1 FROM test
SELECT 1 < 2 FROM test
SELECT 1 <= 2 FROM test
SELECT 1 > 2 FROM test
SELECT 1 >= 2 FROM test
SELECT 1 <> 2 FROM test
SELECT JSON_EXTRACT(x, '$.name')
SELECT JSON_EXTRACT_SCALAR(x, '$.name')
SELECT x LIKE '%x%' FROM test
SELECT * FROM test LIMIT 100
SELECT * FROM test LIMIT 100 OFFSET 200
SELECT * FROM test FETCH FIRST 1 ROWS ONLY
SELECT * FROM test FETCH NEXT 1 ROWS ONLY
SELECT (1 > 2) AS x FROM test
SELECT NOT (1 > 2) FROM test
SELECT 1 + 2 AS x FROM test
SELECT a, b, 1 < 1 FROM test
SELECT a FROM test WHERE NOT FALSE
SELECT a FROM test WHERE a = 1
SELECT a FROM test WHERE a = 1 AND b = 2
SELECT a FROM test WHERE a IN (SELECT b FROM z)
SELECT a FROM test WHERE a IN ((SELECT 1), 2)
SELECT * FROM x WHERE y IN ((SELECT 1) EXCEPT (SELECT 2))
SELECT * FROM x WHERE y IN (SELECT 1 UNION SELECT 2)
SELECT * FROM x WHERE y IN ((SELECT 1 UNION SELECT 2))
SELECT * FROM x WHERE y IN (WITH z AS (SELECT 1) SELECT * FROM z)
SELECT a FROM test WHERE (a > 1)
SELECT a FROM test WHERE a > (SELECT 1 FROM x GROUP BY y)
SELECT a FROM test WHERE EXISTS(SELECT 1)
SELECT a FROM test WHERE EXISTS(SELECT * FROM x UNION SELECT * FROM Y) OR TRUE
SELECT a FROM test WHERE TRUE OR NOT EXISTS(SELECT * FROM x)
SELECT a AS any, b AS some, c AS all, d AS exists FROM test WHERE a = ANY (SELECT 1)
SELECT a FROM test WHERE a > ALL (SELECT 1)
SELECT a FROM test WHERE (a, b) IN (SELECT 1, 2)
SELECT a FROM test ORDER BY a
SELECT a FROM test ORDER BY a, b
SELECT x FROM tests ORDER BY a DESC, b DESC, c
SELECT a FROM test ORDER BY a > 1
SELECT * FROM test ORDER BY DATE DESC, TIMESTAMP DESC
SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l
SELECT * FROM test CLUSTER BY y
SELECT * FROM test CLUSTER BY y
SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()
SELECT a, b FROM test GROUP BY 1
SELECT a, b FROM test GROUP BY a
SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2
SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 ORDER BY a
SELECT a, b FROM test WHERE a = 1 GROUP BY CASE 1 WHEN 1 THEN 1 END
SELECT a FROM test GROUP BY GROUPING SETS (())
SELECT a FROM test GROUP BY GROUPING SETS (x, ())
SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q)
SELECT a FROM test GROUP BY CUBE (x)
SELECT a FROM test GROUP BY ROLLUP (x)
SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z)
SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test
SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END
SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a']
SELECT CASE 1 + 2 WHEN 1 THEN 1 ELSE 2 END
SELECT CASE TEST(1) + x[0] WHEN 1 THEN 1 ELSE 2 END
SELECT CASE x[0] WHEN 1 THEN 1 ELSE 2 END
SELECT CASE a.b WHEN 1 THEN 1 ELSE 2 END
SELECT CASE CASE x > 1 WHEN TRUE THEN 1 END WHEN 1 THEN 1 ELSE 2 END
SELECT a FROM (SELECT a FROM test) AS x
SELECT a FROM (SELECT a FROM (SELECT a FROM test) AS y) AS x
SELECT a FROM test WHERE a IN (1, 2, 3) OR b BETWEEN 1 AND 4
SELECT a FROM test AS x TABLESAMPLE(BUCKET 1 OUT OF 5)
SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5)
SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON x)
SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON RAND())
SELECT a FROM test TABLESAMPLE(0.1 PERCENT)
SELECT a FROM test TABLESAMPLE(100)
SELECT a FROM test TABLESAMPLE(100 ROWS)
SELECT a FROM test TABLESAMPLE BERNOULLI (50)
SELECT a FROM test TABLESAMPLE SYSTEM (75)
SELECT ABS(a) FROM test
SELECT AVG(a) FROM test
SELECT CEIL(a) FROM test
SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test
SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test
SELECT FIRST(a) FROM test
SELECT GREATEST(a, b, c) FROM test
SELECT LAST(a) FROM test
SELECT LN(a) FROM test
SELECT LOG10(a) FROM test
SELECT MAX(a) FROM test
SELECT MIN(a) FROM test
SELECT POWER(a, 2) FROM test
SELECT QUANTILE(a, 0.95) FROM test
SELECT ROUND(a) FROM test
SELECT ROUND(a, 2) FROM test
SELECT SUM(a) FROM test
SELECT SQRT(a) FROM test
SELECT STDDEV(a) FROM test
SELECT STDDEV_POP(a) FROM test
SELECT STDDEV_SAMP(a) FROM test
SELECT VARIANCE(a) FROM test
SELECT VARIANCE_POP(a) FROM test
SELECT CAST(a AS INT) FROM test
SELECT CAST(a AS DATETIME) FROM test
SELECT CAST(a AS VARCHAR) FROM test
SELECT CAST(a < 1 AS INT) FROM test
SELECT CAST(a IS NULL AS INT) FROM test
SELECT COUNT(CAST(1 < 2 AS INT)) FROM test
SELECT COUNT(CASE WHEN CAST(1 < 2 AS BOOLEAN) THEN 1 END) FROM test
SELECT CAST(a AS DECIMAL) FROM test
SELECT CAST(a AS DECIMAL(1)) FROM test
SELECT CAST(a AS DECIMAL(1, 2)) FROM test
SELECT CAST(a AS MAP<INT, INT>) FROM test
SELECT CAST(a AS TIMESTAMP) FROM test
SELECT CAST(a AS DATE) FROM test
SELECT CAST(a AS ARRAY<INT>) FROM test
SELECT TRY_CAST(a AS INT) FROM test
SELECT COALESCE(a, b, c) FROM test
SELECT IFNULL(a, b) FROM test
SELECT ANY_VALUE(a) FROM test
SELECT 1 FROM a JOIN b ON a.x = b.x
SELECT 1 FROM a JOIN b AS c ON a.x = b.x
SELECT 1 FROM a INNER JOIN b ON a.x = b.x
SELECT 1 FROM a LEFT JOIN b ON a.x = b.x
SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x
SELECT 1 FROM a CROSS JOIN b ON a.x = b.x
SELECT 1 FROM a JOIN b USING (x)
SELECT 1 FROM a JOIN b USING (x, y, z)
SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2
SELECT 1 FROM a UNION SELECT 2 FROM b
SELECT 1 FROM a UNION ALL SELECT 2 FROM b
SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar
SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar
SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar
SELECT 1 FROM a OUTER JOIN b ON a.foo = b.bar
SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar
SELECT 1 UNION ALL SELECT 2
SELECT 1 EXCEPT SELECT 2
SELECT 1 EXCEPT SELECT 2
SELECT 1 INTERSECT SELECT 2
SELECT 1 INTERSECT SELECT 2
SELECT 1 AS delete, 2 AS alter
SELECT * FROM (x)
SELECT * FROM ((x))
SELECT * FROM ((SELECT 1))
SELECT * FROM (SELECT 1) AS x
SELECT * FROM (SELECT 1 UNION SELECT 2) AS x
SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x
SELECT * FROM (SELECT 1 UNION ALL SELECT 2)
SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b)
SELECT * FROM ((SELECT 1) AS a(b))
SELECT * FROM x AS y(a, b)
SELECT * EXCEPT (a, b)
SELECT * REPLACE (a AS b, b AS C)
SELECT * REPLACE (a + 1 AS b, b AS C)
SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C)
SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C)
SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals)
WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2
WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
(SELECT 1) UNION (SELECT 2)
(SELECT 1) UNION SELECT 2
SELECT 1 UNION (SELECT 2)
(SELECT 1) ORDER BY x LIMIT 1 OFFSET 1
(SELECT 1 UNION SELECT 2) UNION (SELECT 2 UNION ALL SELECT 3)
(SELECT 1 UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1
(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC
(SELECT 1 UNION SELECT 2) SORT BY z
(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z
(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x
SELECT 1 UNION (SELECT 2) ORDER BY x
(SELECT 1) UNION SELECT 2 ORDER BY x
SELECT * FROM (((SELECT 1) UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1)
SELECT * FROM ((SELECT 1 AS x) CROSS JOIN (SELECT 2 AS y)) AS z
((SELECT 1) EXCEPT (SELECT 2))
VALUES (1) UNION SELECT * FROM x
WITH a AS (SELECT 1) SELECT a.* FROM a
WITH a AS (SELECT 1), b AS (SELECT 2) SELECT a.*, b.* FROM a CROSS JOIN b
WITH a AS (WITH b AS (SELECT 1 AS x) SELECT b.x FROM b) SELECT a.x FROM a
WITH RECURSIVE T(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t
WITH RECURSIVE T(n, m) AS (VALUES (1, 2) UNION ALL SELECT n + 1, n + 2 FROM t) SELECT SUM(n) FROM t
WITH baz AS (SELECT 1 AS col) UPDATE bar SET cid = baz.col1 FROM baz
SELECT * FROM (WITH y AS (SELECT 1 AS z) SELECT z FROM y) AS x
SELECT RANK() OVER () FROM x
SELECT RANK() OVER () AS y FROM x
SELECT RANK() OVER (PARTITION BY a) FROM x
SELECT RANK() OVER (PARTITION BY a, b) FROM x
SELECT RANK() OVER (ORDER BY a) FROM x
SELECT RANK() OVER (ORDER BY a, b) FROM x
SELECT RANK() OVER (PARTITION BY a ORDER BY a) FROM x
SELECT RANK() OVER (PARTITION BY a, b ORDER BY a, b DESC) FROM x
SELECT SUM(x) OVER (PARTITION BY a) AS y FROM x
SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND CURRENT ROW)
SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2' DAYS FOLLOWING)
SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND UNBOUNDED FOLLOWING)
SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND PRECEDING)
SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3)
SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING)
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y
SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC)
SELECT SUM(x) FILTER(WHERE x > 1)
SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y)
SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
SELECT a['1'], b[0], x.c[0], "x".d['1'] FROM x
SELECT ARRAY(1, 2, 3) FROM x
SELECT ARRAY(ARRAY(1), ARRAY(2)) FROM x
SELECT MAP[ARRAY(1), ARRAY(2)] FROM x
SELECT MAP(ARRAY(1), ARRAY(2)) FROM x
SELECT MAX(ARRAY(1, 2, 3)) FROM x
SELECT ARRAY(ARRAY(0))[0][0] FROM x
SELECT MAP[ARRAY('x'), ARRAY(0)]['x'] FROM x
SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores)
SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) AS score
SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score
SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score, name
SELECT student, score FROM tests LATERAL VIEW OUTER EXPLODE(scores) t AS score, name
SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf
SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf AS col0, col1, col2
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b)
SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score)
SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score)
CREATE TABLE a.b AS SELECT 1
CREATE TABLE a.b AS SELECT a FROM a.c
CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE VIEW x AS SELECT a FROM b
CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b
CREATE OR REPLACE VIEW x AS SELECT *
CREATE OR REPLACE TEMPORARY VIEW x AS SELECT *
CREATE TEMPORARY VIEW x AS SELECT a FROM d
CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d
CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y
CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3))
CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3))
CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3))
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
CREATE TABLE z WITH (FORMAT='ORC', x = '2') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1
CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='ORC', x = '2') AS SELECT 1
CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT, y INT))
CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT)) AS SELECT 1
CREATE TABLE z AS (WITH cte AS (SELECT 1) SELECT * FROM cte)
CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
CREATE TABLE z (a INT UNIQUE)
CREATE TABLE z (a INT AUTO_INCREMENT)
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
CREATE FUNCTION f AS 'g'
CREATE INDEX abc ON t (a)
CREATE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX abc ON t (a, b, b)
CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
CACHE TABLE x
CACHE LAZY TABLE x
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value')
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1
CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a
CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2')
INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y
INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y
ALTER TYPE electronic_mail RENAME TO email
ANALYZE a.y
DELETE FROM x WHERE y > 1
DELETE FROM y
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
DROP TABLE IF EXISTS a.b
DROP VIEW a
DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
INSERT INTO x WITH y AS (SELECT 1) SELECT * FROM y
INSERT INTO x.z IF EXISTS SELECT * FROM y
INSERT INTO x VALUES (1, 'a', 2.0)
INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x)
INSERT INTO y (a, b, c) SELECT a, b, c FROM x
INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y
INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y
SELECT 1 FROM PARQUET_SCAN('/x/y/*') AS y
UNCACHE TABLE x
UNCACHE TABLE IF EXISTS x
UPDATE tbl_name SET foo = 123
UPDATE tbl_name SET foo = 123, bar = 345
UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234
UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234
TRUNCATE TABLE x
OPTIMIZE TABLE y
WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a
WITH a AS (SELECT * FROM b) UPDATE a SET col = 1
WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a
WITH a AS (SELECT * FROM b) DELETE FROM a
WITH a AS (SELECT * FROM b) CACHE TABLE a
SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
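
Each line above is an identity fixture: parsing and regenerating it with the default dialect is expected to yield the same string. A minimal sketch of that round trip (the sample line is arbitrary, taken from the fixture above):

import sqlglot

sql = "SELECT * FROM test LIMIT 100 OFFSET 200"
assert sqlglot.transpile(sql)[0] == sql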

View file

@ -0,0 +1,42 @@
SELECT 1 AS x, 2 AS y
UNION ALL
SELECT 1 AS x, 2 AS y;
WITH _e_0 AS (
SELECT
1 AS x,
2 AS y
)
SELECT
*
FROM _e_0
UNION ALL
SELECT
*
FROM _e_0;
SELECT x.id
FROM (
SELECT *
FROM x AS x
JOIN y AS y
ON x.id = y.id
) AS x
JOIN (
SELECT *
FROM x AS x
JOIN y AS y
ON x.id = y.id
) AS y
ON x.id = y.id;
WITH _e_0 AS (
SELECT
*
FROM x AS x
JOIN y AS y
ON x.id = y.id
)
SELECT
x.id
FROM "_e_0" AS x
JOIN "_e_0" AS y
ON x.id = y.id;
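
The block above is a before/after fixture: each input statement is immediately followed by the expected output of an optimizer pass, both terminated by semicolons (here, duplicated subqueries are pulled out into a shared _e_0 CTE). A minimal sketch of how such pairs can be checked; the loader and the transform argument are illustrative assumptions, not code from this diff:

import sqlglot

def check_fixture_pairs(text, transform):
    # Statements alternate: input SQL, then the expected output of `transform`.
    statements = [s.strip() for s in text.split(";") if s.strip()]
    for sql, expected in zip(statements[::2], statements[1::2]):
        actual = transform(sqlglot.parse_one(sql)).sql(pretty=True)
        # Compare through the generator so whitespace differences don't matter.
        assert actual == sqlglot.parse_one(expected).sql(pretty=True), sql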

View file

@ -0,0 +1,11 @@
--------------------------------------
-- Multi Table Selects
--------------------------------------
SELECT * FROM x AS x, y AS y WHERE x.a = y.a;
SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a;
SELECT * FROM x AS x, y AS y WHERE x.a = y.a AND x.a = 1 and y.b = 1;
SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a AND x.a = 1 AND y.b = 1;
SELECT * FROM x AS x, y AS y WHERE x.a > y.a;
SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a > y.a;

View file

@ -0,0 +1,20 @@
SELECT * FROM x AS x, y AS y2;
SELECT * FROM (SELECT * FROM x AS x) AS x, (SELECT * FROM y AS y) AS y2;
SELECT * FROM x AS x WHERE x = 1;
SELECT * FROM x AS x WHERE x = 1;
SELECT * FROM x AS x JOIN y AS y;
SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y AS y) AS y;
SELECT * FROM (SELECT 1) AS x JOIN y AS y;
SELECT * FROM (SELECT 1) AS x JOIN (SELECT * FROM y AS y) AS y;
SELECT * FROM x AS x JOIN (SELECT * FROM y) AS y;
SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y) AS y;
WITH y AS (SELECT *) SELECT * FROM x AS x;
WITH y AS (SELECT *) SELECT * FROM x AS x;
WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y;
WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y;

41
tests/fixtures/optimizer/normalize.sql vendored Normal file
View file

@ -0,0 +1,41 @@
(A OR B) AND (B OR C) AND (E OR F);
(A OR B) AND (B OR C) AND (E OR F);
(A AND B) OR (B AND C AND D);
(A OR C) AND (A OR D) AND B;
(A OR B) AND (A OR C) AND (A OR D) AND (B OR C) AND (B OR D) AND B;
(A OR C) AND (A OR D) AND B;
(A AND E) OR (B AND C) OR (D AND (E OR F));
(A OR B OR D) AND (A OR C OR D) AND (B OR D OR E) AND (B OR E OR F) AND (C OR D OR E) AND (C OR E OR F);
(A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q);
(A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q);
NOT NOT NOT (A OR B);
NOT A AND NOT B;
A OR B;
A OR B;
A AND (B AND C);
A AND B AND C;
A OR (B AND C);
(A OR B) AND (A OR C);
(A AND B) OR C;
(A OR C) AND (B OR C);
A OR (B OR (C AND D));
(A OR B OR C) AND (A OR B OR D);
A OR ((((B OR C) AND (B OR D)) OR C) AND (((B OR C) AND (B OR D)) OR D));
(A OR B OR C) AND (A OR B OR D);
(A AND B) OR (C AND D);
(A OR C) AND (A OR D) AND (B OR C) AND (B OR D);
(A AND B) OR (C OR (D AND E));
(A OR C OR D) AND (A OR C OR E) AND (B OR C OR D) AND (B OR C OR E);
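
The pairs above alternate an input predicate with its conjunctive-normal-form rewrite. A minimal sketch, assuming the rule is exposed as sqlglot.optimizer.normalize.normalize (the module path is an assumption, not taken from this diff):

import sqlglot
from sqlglot.optimizer.normalize import normalize  # assumed module path

# Per the pair above, the DNF-style input is rewritten into CNF.
print(normalize(sqlglot.parse_one("(A AND B) OR (B AND C AND D)")).sql())
# (A OR C) AND (A OR D) AND B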

View file

@ -0,0 +1,20 @@
SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a;
SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = 1 AND y.a = z.a;
SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a;
SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a;
SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a;
SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a;
SELECT * FROM x LEFT JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a;
SELECT * FROM x JOIN z ON x.a = z.a AND TRUE LEFT JOIN y ON y.a = 1 AND y.a = z.a;
SELECT * FROM x INNER JOIN z;
SELECT * FROM x JOIN z;
SELECT * FROM x LEFT OUTER JOIN z;
SELECT * FROM x LEFT JOIN z;
SELECT * FROM x CROSS JOIN z;
SELECT * FROM x CROSS JOIN z;

148
tests/fixtures/optimizer/optimizer.sql vendored Normal file
View file

@ -0,0 +1,148 @@
SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
SELECT
"z"."a" AS "a",
"q"."m" AS "m"
FROM (
SELECT
"z"."a" AS "a"
FROM "z" AS "z"
) AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";
SELECT x FROM UNNEST([1, 2]) AS q(x, y);
SELECT
"q"."x" AS "x"
FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y");
WITH cte AS (
(
SELECT
a
FROM
x
)
UNION ALL
(
SELECT
a
FROM
y
)
)
SELECT
*
FROM
cte;
WITH "cte" AS (
(
SELECT
"x"."a" AS "a"
FROM "x" AS "x"
)
UNION ALL
(
SELECT
"y"."a" AS "a"
FROM "y" AS "y"
)
)
SELECT
"cte"."a" AS "a"
FROM "cte";
WITH cte1 AS (
SELECT a
FROM x
), cte2 AS (
SELECT a + 1 AS a
FROM cte1
)
SELECT
a
FROM cte1
UNION ALL
SELECT
a
FROM cte2;
WITH "cte1" AS (
SELECT
"x"."a" AS "a"
FROM "x" AS "x"
), "cte2" AS (
SELECT
"cte1"."a" + 1 AS "a"
FROM "cte1"
)
SELECT
"cte1"."a" AS "a"
FROM "cte1"
UNION ALL
SELECT
"cte2"."a" AS "a"
FROM "cte2";
SELECT a, SUM(b)
FROM (
SELECT x.a, y.b
FROM x, y
WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a
) d
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
SELECT
"d"."a" AS "a",
SUM("d"."b") AS "_col_1"
FROM (
SELECT
"x"."a" AS "a",
"y"."b" AS "b"
FROM (
SELECT
"x"."a" AS "a"
FROM "x" AS "x"
WHERE
"x"."a" > 1
) AS "x"
LEFT JOIN (
SELECT
MAX("y"."b") AS "_col_0",
"y"."a" AS "_u_1"
FROM "y" AS "y"
GROUP BY
"y"."a"
) AS "_u_0"
ON "x"."a" = "_u_0"."_u_1"
JOIN (
SELECT
"y"."a" AS "a",
"y"."b" AS "b"
FROM "y" AS "y"
) AS "y"
ON "x"."a" = "y"."a"
WHERE
"_u_0"."_col_0" >= 0
AND NOT "_u_0"."_u_1" IS NULL
) AS "d"
GROUP BY
"d"."a";
(SELECT a FROM x) LIMIT 1;
(
SELECT
"x"."a" AS "a"
FROM "x" AS "x"
)
LIMIT 1;
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
(
SELECT
"x"."b" AS "b"
FROM "x" AS "x"
UNION
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
)
LIMIT 1;
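
The pairs above are full optimizer runs: qualified, quoted, and rewritten output for each input query. A minimal sketch, assuming the entry point is sqlglot.optimizer.optimize; the schema below is an illustrative assumption, since the real test schema lives in the test suite, not in this diff:

import sqlglot
from sqlglot.optimizer import optimize  # assumed entry point

schema = {"x": {"a": "INT", "b": "INT"}}  # assumed, for illustration only
expression = sqlglot.parse_one("SELECT a FROM x")
print(optimize(expression, schema=schema).sql(pretty=True))
# Expected shape, per the pairs above:
# SELECT
#   "x"."a" AS "a"
# FROM "x" AS "x"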

View file

@ -0,0 +1,32 @@
SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND y.a = 1;
SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON y.a = 1 WHERE TRUE AND TRUE AND TRUE;
WITH x AS (SELECT y.a FROM y) SELECT * FROM x WHERE x.a = 1;
WITH x AS (SELECT y.a FROM y WHERE y.a = 1) SELECT * FROM x WHERE TRUE;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE y.a = 1 OR (x.a = 1 AND x.b = 1);
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = 1 AND x.b = 1) OR y.a = 1;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.a;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b;
SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1;
SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE;
SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1 or x.c = 2;
SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1 OR x.b * 1 = 2) AS x WHERE TRUE;
SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND (x.c = 1 OR y.c = 1);
SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE);
SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a;
SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE;
SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a = 1;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE;

View file

@ -0,0 +1,41 @@
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT 1 FROM (SELECT * FROM x) WHERE b = 2;
SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2;
SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q;
SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q;
SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b;
SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2;
SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2;
SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2;
SELECT a FROM (SELECT DISTINCT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
WITH y AS (SELECT * FROM x) SELECT a FROM y;
WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y;
WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q;
WITH z AS (SELECT x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q;
WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z;
SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a);
SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0";
SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a);
SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0";

View file

@ -0,0 +1,233 @@
--------------------------------------
-- Qualify columns
--------------------------------------
SELECT a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT a FROM x AS z;
SELECT z.a AS a FROM x AS z;
SELECT a AS a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT x.a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT x.a AS a FROM x;
SELECT x.a AS a FROM x AS x;
SELECT a AS b FROM x;
SELECT x.a AS b FROM x AS x;
SELECT 1, 2 FROM x;
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x;
SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x;
SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x;
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT a AS j, b FROM x ORDER BY j;
SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j;
SELECT a AS j, b FROM x GROUP BY j;
SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a;
SELECT a, b FROM x GROUP BY 1, 2;
SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x ORDER BY 1, 2;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b;
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a);
SELECT a AS a, b FROM x ORDER BY a;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
SELECT a, b FROM x ORDER BY a;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
SELECT a FROM x ORDER BY b;
SELECT x.a AS a FROM x AS x ORDER BY x.b;
# dialect: bigquery
SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1;
SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1;
# dialect: bigquery
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
--------------------------------------
-- Derived tables
--------------------------------------
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y(a);
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS y(c);
SELECT y.c AS c FROM (SELECT x.a AS c, x.b AS b FROM x AS x) AS y;
SELECT a FROM (SELECT a FROM x AS x) y;
SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y;
SELECT a FROM (SELECT a AS a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT a FROM (SELECT a FROM (SELECT a FROM x));
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1";
SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
--------------------------------------
-- Joins
--------------------------------------
SELECT a, c FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT a, c FROM x, y;
SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
--------------------------------------
-- Unions
--------------------------------------
SELECT a FROM x UNION SELECT a FROM x;
SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x;
SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x;
SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x;
SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0";
--------------------------------------
-- Subqueries
--------------------------------------
SELECT a FROM x WHERE b IN (SELECT c FROM y);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
SELECT (SELECT c FROM y) FROM x;
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x;
SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y));
SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0");
--------------------------------------
-- Correlated subqueries
--------------------------------------
SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = x.a);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a);
SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = a);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a);
SELECT a FROM x WHERE b IN (SELECT b FROM y AS x);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
# dialect: bigquery
SELECT aa FROM x, UNNEST(a) AS aa;
SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa;
SELECT aa FROM x, UNNEST(a) AS t(aa);
SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa);
--------------------------------------
-- Expand *
--------------------------------------
SELECT * FROM x;
SELECT x.a AS a, x.b AS b FROM x AS x;
SELECT x.* FROM x;
SELECT x.a AS a, x.b AS b FROM x AS x;
SELECT * FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT x.* FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT x.*, y.* FROM x JOIN y ON x.b = y.b;
SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT a FROM (SELECT * FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
SELECT * FROM (SELECT a FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
--------------------------------------
-- CTEs
--------------------------------------
WITH z AS (SELECT x.a AS a FROM x) SELECT z.a AS a FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z;
WITH z(a) AS (SELECT a FROM x) SELECT * FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z;
WITH z AS (SELECT a FROM x) SELECT * FROM z as q;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT q.a AS a FROM z AS q;
WITH z AS (SELECT a FROM x) SELECT * FROM z;
WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z;
WITH z AS (SELECT a FROM x), q AS (SELECT * FROM z) SELECT * FROM q;
WITH z AS (SELECT x.a AS a FROM x AS x), q AS (SELECT z.a AS a FROM z) SELECT q.a AS a FROM q;
WITH z AS (SELECT * FROM x) SELECT * FROM z UNION SELECT * FROM z;
WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT z.a AS a, z.b AS b FROM z UNION SELECT z.a AS a, z.b AS b FROM z;
WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q;
WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q;
WITH z AS ((SELECT b FROM x UNION ALL SELECT b FROM y) ORDER BY b) SELECT * FROM z;
WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) ORDER BY b) SELECT z.b AS b FROM z;
--------------------------------------
-- Except and Replace
--------------------------------------
SELECT * REPLACE(a AS d) FROM x;
SELECT x.a AS d, x.b AS b FROM x AS x;
SELECT * EXCEPT(b) REPLACE(a AS d) FROM x;
SELECT x.a AS d FROM x AS x;
SELECT x.* EXCEPT(a), y.* FROM x, y;
SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y;
SELECT * EXCEPT(a) FROM x;
SELECT x.b AS b FROM x AS x;
--------------------------------------
-- Using
--------------------------------------
SELECT x.b FROM x JOIN y USING (b);
SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b;
SELECT x.b FROM x JOIN y USING (b) JOIN z USING (b);
SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b;
SELECT b FROM x AS x2 JOIN y AS y2 USING (b);
SELECT COALESCE(x2.b, y2.b) AS b FROM x AS x2 JOIN y AS y2 ON x2.b = y2.b;
SELECT b FROM x JOIN y USING (b) WHERE b = 1 and y.b = 2;
SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALESCE(x.b, y.b) = 1 AND y.b = 2;
SELECT b FROM x JOIN y USING (b) JOIN z USING (b);
SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b;
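
The pairs above qualify every column reference and alias every projection. A minimal sketch, assuming the fixture is driven by qualify_tables followed by qualify_columns (the aliased FROM clauses in the expected outputs suggest both run; the module paths and the schema are assumptions):

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables    # assumed path
from sqlglot.optimizer.qualify_columns import qualify_columns  # assumed path

schema = {"x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}}  # illustrative
expression = sqlglot.parse_one("SELECT a, c FROM x JOIN y ON x.b = y.b")
print(qualify_columns(qualify_tables(expression), schema).sql())
# SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b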

View file

@ -0,0 +1,14 @@
SELECT a FROM zz;
SELECT * FROM zz;
SELECT z.a FROM x;
SELECT z.* FROM x;
SELECT x FROM x;
INSERT INTO x VALUES (1, 2);
SELECT a FROM x AS z JOIN y AS z;
WITH z AS (SELECT * FROM x) SELECT * FROM x AS z;
SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c);
SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a;
SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a;
SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c;
SELECT x.a FROM x JOIN y USING (a);
SELECT a, SUM(b) FROM x GROUP BY 3;

View file

@ -0,0 +1,17 @@
SELECT 1 FROM z;
SELECT 1 FROM c.db.z AS z;
SELECT 1 FROM y.z;
SELECT 1 FROM c.y.z AS z;
SELECT 1 FROM x.y.z;
SELECT 1 FROM x.y.z AS z;
SELECT 1 FROM x.y.z AS z;
SELECT 1 FROM x.y.z AS z;
WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a;
WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a;
SELECT (SELECT y.c FROM y AS y) FROM x;
SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x;
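
The pairs above fill in missing catalog and database qualifiers and alias every table. A minimal sketch, assuming the rule is sqlglot.optimizer.qualify_tables.qualify_tables with db and catalog defaults; the values 'db' and 'c' are read off the expected outputs above:

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables  # assumed path

print(qualify_tables(sqlglot.parse_one("SELECT 1 FROM z"), db="db", catalog="c").sql())
# SELECT 1 FROM c.db.z AS z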

View file

@ -0,0 +1,8 @@
SELECT a FROM x;
SELECT "a" FROM "x";
SELECT "a" FROM "x";
SELECT "a" FROM "x";
SELECT x.a AS a FROM db.x;
SELECT "x"."a" AS "a" FROM "db"."x";

350
tests/fixtures/optimizer/simplify.sql vendored Normal file
View file

@ -0,0 +1,350 @@
--------------------------------------
-- Conditions
--------------------------------------
x AND x;
x;
y OR y;
y;
x AND NOT x;
FALSE;
x OR NOT x;
TRUE;
1 AND TRUE;
TRUE;
TRUE AND TRUE;
TRUE;
1 AND TRUE AND 1 AND 1;
TRUE;
TRUE AND FALSE;
FALSE;
FALSE AND FALSE;
FALSE;
FALSE AND TRUE AND TRUE;
FALSE;
x > y OR FALSE;
x > y;
FALSE OR x = y;
x = y;
1 = 1;
TRUE;
1.0 = 1;
TRUE;
'x' = 'y';
FALSE;
'x' = 'x';
TRUE;
NULL AND TRUE;
NULL;
NULL AND NULL;
NULL;
NULL OR TRUE;
TRUE;
NULL OR NULL;
NULL;
FALSE OR NULL;
NULL;
NOT TRUE;
FALSE;
NOT FALSE;
TRUE;
NULL = NULL;
NULL;
NOT (NOT TRUE);
TRUE;
a AND (b OR b);
a AND b;
a AND (b AND b);
a AND b;
--------------------------------------
-- Absorption
--------------------------------------
(A OR B) AND (C OR NOT A);
(A OR B) AND (C OR NOT A);
A AND (A OR B);
A;
A AND D AND E AND (B OR A);
A AND D AND E;
D AND A AND E AND (B OR A);
A AND D AND E;
(A OR B) AND A;
A;
C AND D AND (A OR B) AND E AND F AND A;
A AND C AND D AND E AND F;
A OR (A AND B);
A;
(A AND B) OR A;
A;
A AND (NOT A OR B);
A AND B;
(NOT A OR B) AND A;
A AND B;
A OR (NOT A AND B);
A OR B;
(A OR C) AND ((A OR C) OR B);
A OR C;
(A OR C) AND (A OR B OR C);
A OR C;
--------------------------------------
-- Elimination
--------------------------------------
(A AND B) OR (A AND NOT B);
A;
(A AND B) OR (NOT A AND B);
B;
(A AND NOT B) OR (A AND B);
A;
(NOT A AND B) OR (A AND B);
B;
(A OR B) AND (A OR NOT B);
A;
(A OR B) AND (NOT A OR B);
B;
(A OR NOT B) AND (A OR B);
A;
(NOT A OR B) AND (A OR B);
B;
(NOT A OR NOT B) AND (NOT A OR B);
NOT A;
(NOT A OR NOT B) AND (NOT A OR NOT NOT B);
NOT A;
E OR (A AND B) OR C OR D OR (A AND NOT B);
A OR C OR D OR E;
--------------------------------------
-- Associativity
--------------------------------------
(A AND B) AND C;
A AND B AND C;
A AND (B AND C);
A AND B AND C;
(A OR B) OR C;
A OR B OR C;
A OR (B OR C);
A OR B OR C;
((A AND B) AND C) AND D;
A AND B AND C AND D;
(((((A) AND B)) AND C)) AND D;
A AND B AND C AND D;
--------------------------------------
-- Comparison and Pruning
--------------------------------------
A AND D AND B AND E AND F AND G AND E AND A;
A AND B AND D AND E AND F AND G;
A AND NOT B AND C AND B;
FALSE;
(a AND b AND c AND d) AND (d AND c AND b AND a);
a AND b AND c AND d;
(c AND (a AND b)) AND ((b AND a) AND c);
a AND b AND c;
(A AND B AND C) OR (C AND B AND A);
A AND B AND C;
--------------------------------------
-- Where removal
--------------------------------------
SELECT x WHERE TRUE;
SELECT x;
--------------------------------------
-- Parenthesis removal
--------------------------------------
(TRUE);
TRUE;
(FALSE);
FALSE;
(FALSE OR TRUE);
TRUE;
TRUE OR (((FALSE) OR (TRUE)) OR FALSE);
TRUE;
(NOT FALSE) AND (NOT TRUE);
FALSE;
((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
TRUE;
((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2);
TRUE;
(('a' = 'a') AND TRUE and NOT FALSE);
TRUE;
--------------------------------------
-- Literals
--------------------------------------
1 + 1;
2;
0.06 + 0.01;
0.07;
0.06 + 1;
1.06;
1.2E+1 + 15E-3;
12.015;
1.2E1 + 15E-3;
12.015;
1 - 2;
-1;
-1 + 3;
2;
-(-1);
1;
0.06 - 0.01;
0.05;
3 * 4;
12;
3.0 * 9;
27.0;
0.03 * 0.73;
0.0219;
1 / 3;
0;
20.0 / 6;
3.333333333333333333333333333;
10 / 5;
2;
(1.0 * 3) * 4 - 2 * (5 / 2);
8.0;
6 - 2 + 4 * 2 + a;
12 + a;
a + 1 + 1 + 2;
a + 4;
a + (1 + 1) + (10);
a + 12;
5 + 4 * 3;
17;
1 < 2;
TRUE;
2 <= 2;
TRUE;
2 >= 2;
TRUE;
2 > 1;
TRUE;
2 > 2.5;
FALSE;
3 > 2.5;
TRUE;
1 > NULL;
NULL;
1 <= NULL;
NULL;
1 IS NULL;
FALSE;
NULL IS NULL;
TRUE;
NULL IS NOT NULL;
FALSE;
1 IS NOT NULL;
TRUE;
date '1998-12-01' - interval '90' day;
CAST('1998-09-02' AS DATE);
date '1998-12-01' + interval '1' week;
CAST('1998-12-08' AS DATE);
interval '1' year + date '1998-01-01';
CAST('1999-01-01' AS DATE);
interval '1' year + date '1998-01-01' + 3 * 7 * 4;
CAST('1999-01-01' AS DATE) + 84;
date '1998-12-01' - interval '90' foo;
CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
date '1998-12-01' + interval '90' foo;
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
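
The pairs above constant-fold literals and prune always-true or always-false boolean logic. A minimal sketch, assuming the rule is exposed as sqlglot.optimizer.simplify.simplify (module path assumed):

import sqlglot
from sqlglot.optimizer.simplify import simplify  # assumed module path

print(simplify(sqlglot.parse_one("1 + 1")).sql())        # 2
print(simplify(sqlglot.parse_one("x AND NOT x")).sql())  # FALSE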

Binary file not shown.

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff